forked from mindspore-Ecosystem/mindspore
!1398 Update the bert scripts according to rules of modelzoo
Merge pull request !1398 from chenhaozhe/update_bert_script
commit b46ad9a1bb
@@ -308,7 +308,7 @@ def get_bprop_softmax(self):
     axis = self.axis

     def bprop(x, out, dout):
-        dx = mul(sub(dout, sum_func(mul(dout, out), axis)), out)
+        dx = mul(out, sub(dout, sum_func(mul(out, dout), axis)))
         return (dx,)

     return bprop
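The only functional change in this hunk is the operand order of the softmax backward rule. A minimal NumPy sketch (my own illustration, not part of the commit) showing that both orderings compute the same dx = out * (dout - sum(dout * out, axis)):

```python
import numpy as np

rng = np.random.default_rng(0)
out = rng.random((2, 5))
out /= out.sum(axis=-1, keepdims=True)      # pretend softmax output
dout = rng.random((2, 5))                   # upstream gradient

s = np.sum(dout * out, axis=-1, keepdims=True)
dx_old = (dout - s) * out                   # mul(sub(dout, sum_func(mul(dout, out), axis)), out)
dx_new = out * (dout - s)                   # mul(out, sub(dout, sum_func(mul(out, dout), axis)))
assert np.allclose(dx_old, dx_new)
```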
@@ -16,12 +16,12 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
 - Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.

 ``` bash
-sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
+sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
 ```
 - Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.

 ``` bash
-sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
+sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
 ```

 ### Fine-Tuning
@@ -19,8 +19,6 @@ Bert evaluation script.

 import os
 import numpy as np
-from evaluation_config import cfg, bert_net_cfg
-from utils import BertNER, BertCLS
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.common.tensor import Tensor
@@ -28,9 +26,11 @@ import mindspore.dataset as de
 import mindspore.dataset.transforms.c_transforms as C
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from CRF import postprocess
-from cluener_evaluation import submit
-from finetune_config import tag_to_index
+from src.evaluation_config import cfg, bert_net_cfg
+from src.utils import BertNER, BertCLS
+from src.CRF import postprocess
+from src.cluener_evaluation import submit
+from src.finetune_config import tag_to_index

 class Accuracy():
     '''
@@ -18,8 +18,8 @@ Bert finetune script.
 '''

 import os
-from utils import BertFinetuneCell, BertCLS, BertNER
-from finetune_config import cfg, bert_net_cfg, tag_to_index
+from src.utils import BertFinetuneCell, BertCLS, BertNER
+from src.finetune_config import cfg, bert_net_cfg, tag_to_index
 import mindspore.common.dtype as mstype
 import mindspore.communication.management as D
 from mindspore import context
@@ -26,10 +26,10 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
-from dataset import create_bert_dataset
-from config import cfg, bert_net_cfg
+from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
+from src.dataset import create_bert_dataset
+from src.config import cfg, bert_net_cfg
 _current_dir = os.path.dirname(os.path.realpath(__file__))

 class LossCallBack(Callback):
@@ -48,10 +48,8 @@ class LossCallBack(Callback):
         self._per_print_times = per_print_times
     def step_end(self, run_context):
         cb_params = run_context.original_args()
-        with open("./loss.log", "a+") as f:
-            f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                                 str(cb_params.net_outputs)))
-            f.write('\n')
+        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
+                                                           str(cb_params.net_outputs)))

 def run_pretrain():
     """pre-train bert_clue"""
@@ -81,6 +79,11 @@ def run_pretrain():
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
+        from mindspore.parallel._auto_parallel_context import auto_parallel_context
+        if bert_net_cfg.num_hidden_layers == 12:
+            auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
+        elif bert_net_cfg.num_hidden_layers == 24:
+            auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
         D.init()
         rank = args_opt.device_id % device_num
     else:
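The added lines register all-reduce fusion split points for the data-parallel run. As a rough illustration (my own sketch, not from the commit), the indices partition the flat list of gradient tensors into buckets, each of which is fused into a single all-reduce:

```python
# Hypothetical illustration of how split indices bucket gradients for fused all-reduce.
def bucket_by_split_indices(num_params, split_indices):
    """Return [start, end) index ranges, one per fused all-reduce bucket."""
    buckets, start = [], 0
    for end in split_indices:
        buckets.append((start, end))
        start = end
    if start < num_params:
        buckets.append((start, num_params))
    return buckets

# For the 12-layer network the last split index (205) matches the gradient count used here.
print(bucket_by_split_indices(205, [28, 55, 82, 109, 136, 163, 190, 205]))
```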
@@ -16,8 +16,8 @@

 echo "=============================================================================================================="
 echo "Please run the scipt as: "
-echo "sh run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH"
-echo "for example: sh run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json"
+echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH"
+echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json /path/hccl.json"
 echo "It is better to use absolute path."
 echo "=============================================================================================================="
@@ -49,6 +49,10 @@ do
     cp *.py ./LOG$i
     cd ./LOG$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
+    mkdir -p ms_log
+    CUR_DIR=`pwd`
+    export GLOG_log_dir=${CUR_DIR}/ms_log
+    export GLOG_logtostderr=0
     env > env.log
     taskset -c $cmdopt python ../run_pretrain.py \
         --distribute="true" \
@@ -59,7 +63,7 @@ do
         --enable_lossscale="true" \
         --do_shuffle="true" \
         --enable_data_sink="true" \
-        --data_sink_steps=1 \
+        --data_sink_steps=100 \
         --checkpoint_path="" \
         --save_checkpoint_steps=10000 \
         --save_checkpoint_num=1 \
@@ -16,8 +16,8 @@

 echo "=============================================================================================================="
 echo "Please run the scipt as: "
-echo "sh run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
-echo "for example: sh run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
+echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
+echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
 echo "=============================================================================================================="

 DEVICE_ID=$1
@@ -25,6 +25,10 @@ EPOCH_SIZE=$2
 DATA_DIR=$3
 SCHEMA_DIR=$4

+mkdir -p ms_log
+CUR_DIR=`pwd`
+export GLOG_log_dir=${CUR_DIR}/ms_log
+export GLOG_logtostderr=0
 python run_pretrain.py \
     --distribute="false" \
     --epoch_size=$EPOCH_SIZE \
@@ -33,7 +37,7 @@ python run_pretrain.py \
     --enable_lossscale="true" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
-    --data_sink_steps=1 \
+    --data_sink_steps=100 \
    --checkpoint_path="" \
     --save_checkpoint_steps=10000 \
     --save_checkpoint_num=1 \
@@ -357,10 +357,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
         self.grad_reducer = F.identity
+        self.degree = 1
         if self.reducer_flag:
-            mean = context.get_auto_parallel_context("mirror_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.cast = P.Cast()
         self.alloc_status = P.NPUAllocFloatStatus()
@@ -411,10 +411,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
                                                  masked_lm_weights,
                                                  self.cast(scaling_sens,
                                                            mstype.float32))
-        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
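The reordering above unscales gradients only after the reducer runs. Assuming the reducer with `mean=False` sums gradients across `degree` devices, dividing once by `scaling_sens * self.degree` recovers the same averaged, unscaled gradient that the old unscale-then-mean order produced; a toy NumPy check (my own sketch, not part of the commit):

```python
import numpy as np

degree = 8                      # number of devices
scaling_sens = 1024.0           # current loss scale
rng = np.random.default_rng(1)
scaled_grads = rng.random((degree, 4)) * scaling_sens    # per-device gradients of (loss * scale)

old_way = (scaled_grads / scaling_sens).mean(axis=0)              # unscale, then mean all-reduce
new_way = scaled_grads.sum(axis=0) / (scaling_sens * degree)      # sum all-reduce, then unscale once
assert np.allclose(old_way, new_way)
```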
@@ -25,6 +25,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from .fused_layer_norm import FusedLayerNorm


 class BertConfig:
@@ -77,7 +78,8 @@ class BertConfig:
                  input_mask_from_dataset=True,
                  token_type_ids_from_dataset=True,
                  dtype=mstype.float32,
-                 compute_type=mstype.float32):
+                 compute_type=mstype.float32,
+                 enable_fused_layernorm=False):
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.vocab_size = vocab_size
@@ -96,6 +98,7 @@ class BertConfig:
         self.use_relative_positions = use_relative_positions
         self.dtype = dtype
         self.compute_type = compute_type
+        self.enable_fused_layernorm = enable_fused_layernorm


 class EmbeddingLookup(nn.Cell):
@@ -240,13 +243,19 @@ class BertOutput(nn.Cell):
                  out_channels,
                  initializer_range=0.02,
                  dropout_prob=0.1,
-                 compute_type=mstype.float32):
+                 compute_type=mstype.float32,
+                 enable_fused_layernorm=False):
         super(BertOutput, self).__init__()
         self.dense = nn.Dense(in_channels, out_channels,
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
+        self.dropout_prob = dropout_prob
         self.add = P.TensorAdd()
-        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+        if compute_type == mstype.float16:
+            self.layernorm = FusedLayerNorm((out_channels,),
+                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
+        else:
+            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()

     def construct(self, hidden_status, input_tensor):
@@ -481,12 +490,13 @@ class BertAttention(nn.Cell):
         self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

         self.cast_compute_type = SaturateCast(dst_type=compute_type)
-        self._generate_relative_positions_embeddings = \
-            RelaPosEmbeddingsGenerator(length=to_seq_length,
-                                       depth=size_per_head,
-                                       max_relative_position=16,
-                                       initializer_range=initializer_range,
-                                       use_one_hot_embeddings=use_one_hot_embeddings)
+        if self.use_relative_positions:
+            self._generate_relative_positions_embeddings = \
+                RelaPosEmbeddingsGenerator(length=to_seq_length,
+                                           depth=size_per_head,
+                                           max_relative_position=16,
+                                           initializer_range=initializer_range,
+                                           use_one_hot_embeddings=use_one_hot_embeddings)

     def construct(self, from_tensor, to_tensor, attention_mask):
         # reshape 2d/3d input tensors to 2d
@@ -529,7 +539,7 @@ class BertAttention(nn.Cell):
                                                    self.trans_shape_position)
             attention_scores = attention_scores + key_position_scores_r_t

-        attention_scores = self.multiply(attention_scores, self.scores_mul)
+        attention_scores = self.multiply(self.scores_mul, attention_scores)

         if self.has_attention_mask:
             attention_mask = self.expand_dims(attention_mask, 1)
@@ -606,7 +616,8 @@ class BertSelfAttention(nn.Cell):
                  initializer_range=0.02,
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
-                 compute_type=mstype.float32):
+                 compute_type=mstype.float32,
+                 enable_fused_layernorm=False):
         super(BertSelfAttention, self).__init__()
         if hidden_size % num_attention_heads != 0:
             raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -634,7 +645,8 @@ class BertSelfAttention(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type)
+                                 compute_type=compute_type,
+                                 enable_fused_layernorm=enable_fused_layernorm)
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)

@@ -676,7 +688,8 @@ class BertEncoderCell(nn.Cell):
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
                  hidden_act="gelu",
-                 compute_type=mstype.float32):
+                 compute_type=mstype.float32,
+                 enable_fused_layernorm=False):
         super(BertEncoderCell, self).__init__()
         self.attention = BertSelfAttention(
             batch_size=batch_size,
@@ -688,7 +701,8 @@ class BertEncoderCell(nn.Cell):
             initializer_range=initializer_range,
             hidden_dropout_prob=hidden_dropout_prob,
             use_relative_positions=use_relative_positions,
-            compute_type=compute_type)
+            compute_type=compute_type,
+            enable_fused_layernorm=enable_fused_layernorm)
         self.intermediate = nn.Dense(in_channels=hidden_size,
                                      out_channels=intermediate_size,
                                      activation=hidden_act,
@@ -697,7 +711,8 @@ class BertEncoderCell(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type)
+                                 compute_type=compute_type,
+                                 enable_fused_layernorm=enable_fused_layernorm)

     def construct(self, hidden_states, attention_mask):
         # self-attention
@@ -744,7 +759,8 @@ class BertTransformer(nn.Cell):
                  use_relative_positions=False,
                  hidden_act="gelu",
                  compute_type=mstype.float32,
-                 return_all_encoders=False):
+                 return_all_encoders=False,
+                 enable_fused_layernorm=False):
         super(BertTransformer, self).__init__()
         self.return_all_encoders = return_all_encoders

@@ -761,7 +777,8 @@ class BertTransformer(nn.Cell):
                                     hidden_dropout_prob=hidden_dropout_prob,
                                     use_relative_positions=use_relative_positions,
                                     hidden_act=hidden_act,
-                                    compute_type=compute_type)
+                                    compute_type=compute_type,
+                                    enable_fused_layernorm=enable_fused_layernorm)
             layers.append(layer)

         self.layers = nn.CellList(layers)
@@ -888,7 +905,8 @@ class BertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True)
+            return_all_encoders=True,
+            enable_fused_layernorm=config.enable_fused_layernorm)

         self.cast = P.Cast()
         self.dtype = config.dtype
@@ -17,12 +17,12 @@

 import json
 import numpy as np
-from evaluation_config import cfg
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
-from CRF import postprocess
 import tokenization
 from sample_process import label_generation, process_one_example_p
+from .evaluation_config import cfg
+from .CRF import postprocess

 vocab_file = "./vocab.txt"
 tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)
@@ -17,16 +17,16 @@ network config setting, will be used in dataset.py, run_pretrain.py
 """
 from easydict import EasyDict as edict
 import mindspore.common.dtype as mstype
-from mindspore.model_zoo.Bert_NEZHA import BertConfig
+from .bert_model import BertConfig
 cfg = edict({
     'bert_network': 'base',
-    'loss_scale_value': 2**32,
+    'loss_scale_value': 65536,
     'scale_factor': 2,
     'scale_window': 1000,
     'optimizer': 'Lamb',
     'AdamWeightDecayDynamicLR': edict({
         'learning_rate': 3e-5,
-        'end_learning_rate': 1e-7,
+        'end_learning_rate': 1e-10,
         'power': 5.0,
         'weight_decay': 1e-5,
         'eps': 1e-6,
|
@ -34,7 +34,7 @@ cfg = edict({
|
||||||
}),
|
}),
|
||||||
'Lamb': edict({
|
'Lamb': edict({
|
||||||
'start_learning_rate': 3e-5,
|
'start_learning_rate': 3e-5,
|
||||||
'end_learning_rate': 1e-7,
|
'end_learning_rate': 1e-10,
|
||||||
'power': 10.0,
|
'power': 10.0,
|
||||||
'warmup_steps': 10000,
|
'warmup_steps': 10000,
|
||||||
'weight_decay': 0.01,
|
'weight_decay': 0.01,
|
||||||
|
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21128,
+        vocab_size=21136,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
|
@ -71,13 +71,13 @@ if cfg.bert_network == 'base':
|
||||||
input_mask_from_dataset=True,
|
input_mask_from_dataset=True,
|
||||||
token_type_ids_from_dataset=True,
|
token_type_ids_from_dataset=True,
|
||||||
dtype=mstype.float32,
|
dtype=mstype.float32,
|
||||||
compute_type=mstype.float16,
|
compute_type=mstype.float16
|
||||||
)
|
)
|
||||||
if cfg.bert_network == 'nezha':
|
if cfg.bert_network == 'nezha':
|
||||||
bert_net_cfg = BertConfig(
|
bert_net_cfg = BertConfig(
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
seq_length=128,
|
seq_length=128,
|
||||||
vocab_size=21128,
|
vocab_size=21136,
|
||||||
hidden_size=1024,
|
hidden_size=1024,
|
||||||
num_hidden_layers=24,
|
num_hidden_layers=24,
|
||||||
num_attention_heads=16,
|
num_attention_heads=16,
|
||||||
|
@ -92,5 +92,27 @@ if cfg.bert_network == 'nezha':
|
||||||
input_mask_from_dataset=True,
|
input_mask_from_dataset=True,
|
||||||
token_type_ids_from_dataset=True,
|
token_type_ids_from_dataset=True,
|
||||||
dtype=mstype.float32,
|
dtype=mstype.float32,
|
||||||
compute_type=mstype.float16,
|
compute_type=mstype.float16
|
||||||
|
)
|
||||||
|
if cfg.bert_network == 'large':
|
||||||
|
bert_net_cfg = BertConfig(
|
||||||
|
batch_size=16,
|
||||||
|
seq_length=512,
|
||||||
|
vocab_size=30528,
|
||||||
|
hidden_size=1024,
|
||||||
|
num_hidden_layers=24,
|
||||||
|
num_attention_heads=16,
|
||||||
|
intermediate_size=4096,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
use_relative_positions=False,
|
||||||
|
input_mask_from_dataset=True,
|
||||||
|
token_type_ids_from_dataset=True,
|
||||||
|
dtype=mstype.float32,
|
||||||
|
compute_type=mstype.float16,
|
||||||
|
enable_fused_layernorm=True
|
||||||
)
|
)
|
|
@@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
 from mindspore import log as logger
-from config import bert_net_cfg
+from .config import bert_net_cfg


 def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
|
@ -31,8 +31,9 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
|
||||||
files = os.listdir(data_dir)
|
files = os.listdir(data_dir)
|
||||||
data_files = []
|
data_files = []
|
||||||
for file_name in files:
|
for file_name in files:
|
||||||
data_files.append(os.path.join(data_dir, file_name))
|
if "tfrecord" in file_name:
|
||||||
ds = de.TFRecordDataset(data_files, schema_dir,
|
data_files.append(os.path.join(data_dir, file_name))
|
||||||
|
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
|
||||||
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
||||||
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
|
||||||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
|
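The new condition only picks up TFRecord shards and tolerates an empty schema path. A small standalone sketch of the same file-filtering idea (plain Python, hypothetical directory contents, not taken from the repository):

```python
import os

def list_tfrecord_files(data_dir):
    """Keep only files whose name contains 'tfrecord', mirroring the dataset change."""
    return [os.path.join(data_dir, name)
            for name in sorted(os.listdir(data_dir))
            if "tfrecord" in name]
```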
@@ -19,7 +19,7 @@ config settings, will be used in finetune.py

 from easydict import EasyDict as edict
 import mindspore.common.dtype as mstype
-from mindspore.model_zoo.Bert_NEZHA import BertConfig
+from .bert_model import BertConfig

 cfg = edict({
     'task': 'NER',
@@ -19,7 +19,7 @@ config settings, will be used in finetune.py

 from easydict import EasyDict as edict
 import mindspore.common.dtype as mstype
-from mindspore.model_zoo.Bert_NEZHA import BertConfig
+from .bert_model import BertConfig

 cfg = edict({
     'task': 'NER',
@@ -0,0 +1,121 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""fused layernorm"""
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore.ops.primitive import constexpr
+import mindspore.common.dtype as mstype
+from mindspore.nn.cell import Cell
+
+import numpy as np
+
+
+__all__ = ['FusedLayerNorm']
+
+@constexpr
+def get_shape_for_norm(x_shape, begin_norm_axis):
+    print("input_shape: ", x_shape)
+    norm_shape = x_shape[begin_norm_axis:]
+    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
+    print("output_shape: ", output_shape)
+    return output_shape
+
+class FusedLayerNorm(Cell):
+    r"""
+    Applies Layer Normalization over a mini-batch of inputs.
+
+    Layer normalization is widely used in recurrent neural networks. It applies
+    normalization over a mini-batch of inputs for each single training case as described
+    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
+    normalization, layer normalization performs exactly the same computation at training and
+    testing times. It can be described using the following formula. It is applied across all channels
+    and pixel but only one batch size.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
+            `begin_norm_axis ... R - 1`.
+        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
+            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
+        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
+            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
+            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        use_batch_norm (bool): Whether to use batchnorm to process.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
+          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
+
+    Outputs:
+        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
+
+    Examples:
+        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
+        >>> shape1 = x.shape()[1:]
+        >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
+        >>> m(x)
+    """
+    def __init__(self,
+                 normalized_shape,
+                 begin_norm_axis=-1,
+                 begin_params_axis=-1,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 use_batch_norm=False):
+        super(FusedLayerNorm, self).__init__()
+        if not isinstance(normalized_shape, (tuple, list)):
+            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
+                            .format(normalized_shape, type(normalized_shape)))
+        self.normalized_shape = normalized_shape
+        self.begin_norm_axis = begin_norm_axis
+        self.begin_params_axis = begin_params_axis
+        self.gamma = Parameter(initializer(
+            gamma_init, normalized_shape), name="gamma")
+        self.beta = Parameter(initializer(
+            beta_init, normalized_shape), name="beta")
+        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
+
+        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
+        self.use_batch_norm = use_batch_norm
+
+    def construct(self, input_x):
+        if self.use_batch_norm and self.training:
+            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
+            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
+            shape_x = F.shape(input_x)
+            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
+            input_x = F.reshape(input_x, norm_shape)
+            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
+            output = F.reshape(output, shape_x)
+            y = output * self.gamma + self.beta
+        else:
+            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
+        return y
+
+    def extend_repr(self):
+        """Display instance object as string."""
+        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
+            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
+        return s
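A hedged usage sketch for the new cell, loosely following its own docstring example (the import path, input shape, and the assumption that a device context is already configured are mine, not part of the commit):

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from src.fused_layer_norm import FusedLayerNorm   # assuming the file lands under src/

x = Tensor(np.ones([2, 128, 768]), mstype.float32)
norm = FusedLayerNorm((768,), use_batch_norm=True)  # batch-norm-backed path is used only in training mode
y = norm(x)                                          # same shape and dtype as x
```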
@@ -30,8 +30,8 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
-from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
-from CRF import CRF
+from .bert_for_pre_training import clip_grad
+from .CRF import CRF

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -25,7 +25,8 @@ import mindspore.dataset.transforms.c_transforms as C
 from mindspore import context
 from mindspore import log as logger
 from mindspore.common.tensor import Tensor
-from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
+from src.bert_model import BertConfig
+from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
 from mindspore.nn.optim import Lamb
 from mindspore.train.callback import Callback
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
@@ -77,7 +78,8 @@ def get_config(version='base', batch_size=1):
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
-            compute_type=mstype.float16)
+            compute_type=mstype.float16,
+            enable_fused_layernorm=False)
     else:
         bert_config = BertConfig(batch_size=batch_size)
     return bert_config

@@ -0,0 +1,177 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+'''
+CRF script.
+'''
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter
+import mindspore.common.dtype as mstype
+
+class CRF(nn.Cell):
+    '''
+    Conditional Random Field
+    Args:
+        tag_to_index: The dict for tag to index mapping with extra "<START>" and "<STOP>" sign.
+        batch_size: Batch size, i.e., the length of the first dimension.
+        seq_length: Sequence length, i.e., the length of the second dimension.
+        is_training: Specifies whether to use training mode.
+    Returns:
+        Training mode: Tensor, total loss.
+        Evaluation mode: Tuple, the index for each step with the highest score; Tuple, the index for the last
+        step with the highest score.
+    '''
+    def __init__(self, tag_to_index, batch_size=1, seq_length=128, is_training=True):
+
+        super(CRF, self).__init__()
+        self.target_size = len(tag_to_index)
+        self.is_training = is_training
+        self.tag_to_index = tag_to_index
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.START_TAG = "<START>"
+        self.STOP_TAG = "<STOP>"
+        self.START_VALUE = Tensor(self.target_size-2, dtype=mstype.int32)
+        self.STOP_VALUE = Tensor(self.target_size-1, dtype=mstype.int32)
+        transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
+        transitions[tag_to_index[self.START_TAG], :] = -10000
+        transitions[:, tag_to_index[self.STOP_TAG]] = -10000
+        self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
+        self.cat = P.Concat(axis=-1)
+        self.argmax = P.ArgMaxWithValue(axis=-1)
+        self.log = P.Log()
+        self.exp = P.Exp()
+        self.sum = P.ReduceSum()
+        self.tile = P.Tile()
+        self.reduce_sum = P.ReduceSum(keep_dims=True)
+        self.reshape = P.Reshape()
+        self.expand = P.ExpandDims()
+        self.mean = P.ReduceMean()
+        init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0
+        init_alphas[:, self.tag_to_index[self.START_TAG]] = 0.
+        self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
+        self.cast = P.Cast()
+        self.reduce_max = P.ReduceMax(keep_dims=True)
+        self.on_value = Tensor(1.0, dtype=mstype.float32)
+        self.off_value = Tensor(0.0, dtype=mstype.float32)
+        self.onehot = P.OneHot()
+
+    def log_sum_exp(self, logits):
+        '''
+        Compute the log_sum_exp score for normalization factor.
+        '''
+        max_score = self.reduce_max(logits, -1)  # 16 5 5
+        score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
+        score = max_score + score
+        return score
+
+    def _realpath_score(self, features, label):
+        '''
+        Compute the emission and transition score for the real path.
+        '''
+        label = label * 1
+        concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
+        concat_A = self.reshape(concat_A, (self.batch_size, 1))
+        labels = self.cat((concat_A, label))
+        onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
+        emits = features * onehot_label
+        labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
+        label1 = labels[:, 1:, :]
+        label2 = labels[:, :self.seq_length, :]
+        label1 = self.expand(label1, 3)
+        label2 = self.expand(label2, 2)
+        label_trans = label1 * label2
+        transitions = self.expand(self.expand(self.transitions, 0), 0)
+        trans = transitions * label_trans
+        score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
+        stop_value_index = labels[:, (self.seq_length-1):self.seq_length, :]
+        stop_value = self.transitions[(self.target_size-1):self.target_size, :]
+        stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
+        score = score + self.sum(stop_score, 1)
+        score = self.reshape(score, (self.batch_size, -1))
+        return score
+
+    def _normalization_factor(self, features):
+        '''
+        Compute the total score for all the paths.
+        '''
+        forward_var = self.init_alphas
+        forward_var = self.expand(forward_var, 1)
+        for idx in range(self.seq_length):
+            feat = features[:, idx:(idx+1), :]
+            emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
+            next_tag_var = emit_score + self.transitions + forward_var
+            forward_var = self.log_sum_exp(next_tag_var)
+            forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
+        terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
+        alpha = self.log_sum_exp(terminal_var)
+        alpha = self.reshape(alpha, (self.batch_size, -1))
+        return alpha
+
+    def _decoder(self, features):
+        '''
+        Viterbi decode for evaluation.
+        '''
+        backpointers = ()
+        forward_var = self.init_alphas
+        for idx in range(self.seq_length):
+            feat = features[:, idx:(idx+1), :]
+            feat = self.reshape(feat, (self.batch_size, self.target_size))
+            bptrs_t = ()
+
+            next_tag_var = self.expand(forward_var, 1) + self.transitions
+            best_tag_id, best_tag_value = self.argmax(next_tag_var)
+            bptrs_t += (best_tag_id,)
+            forward_var = best_tag_value + feat
+
+            backpointers += (bptrs_t,)
+        terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
+        best_tag_id, _ = self.argmax(terminal_var)
+        return backpointers, best_tag_id
+
+    def construct(self, features, label):
+        if self.is_training:
+            forward_score = self._normalization_factor(features)
+            gold_score = self._realpath_score(features, label)
+            return_value = self.mean(forward_score - gold_score)
+        else:
+            path_list, tag = self._decoder(features)
+            return_value = path_list, tag
+        return return_value
+
+def postprocess(backpointers, best_tag_id):
+    '''
+    Do postprocess
+    '''
+    best_tag_id = best_tag_id.asnumpy()
+    batch_size = len(best_tag_id)
+    best_path = []
+    for i in range(batch_size):
+        best_path.append([])
+        best_local_id = best_tag_id[i]
+        best_path[-1].append(best_local_id)
+        for bptrs_t in reversed(backpointers):
+            bptrs_t = bptrs_t[0].asnumpy()
+            local_idx = bptrs_t[i]
+            best_local_id = local_idx[best_local_id]
+            best_path[-1].append(best_local_id)
+        # Pop off the start tag (we don't want to return that to the caller)
+        best_path[-1].pop()
+        best_path[-1].reverse()
+    return best_path
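The `_normalization_factor` loop above is the standard CRF forward algorithm in log space. A NumPy sketch of the same recurrence on a tiny tag set (my own illustration, not taken from the file):

```python
import numpy as np

def log_sum_exp(x, axis=-1):
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

T, K = 4, 3                                    # sequence length, number of tags
rng = np.random.default_rng(0)
emissions = rng.random((T, K))                 # per-step emission scores
transitions = rng.random((K, K))               # transitions[i, j]: score of tag i followed by tag j

alpha = emissions[0]
for t in range(1, T):
    # alpha[j] = logsumexp_i(alpha[i] + transitions[i, j]) + emissions[t, j]
    alpha = log_sum_exp(alpha[:, None] + transitions, axis=0) + emissions[t]
log_partition = log_sum_exp(alpha)             # normalization factor over all tag paths
```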
@@ -0,0 +1,31 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Bert Init."""
+from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
+    BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
+    BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
+from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
+    BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
+    EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
+    SaturateCast, CreateAttentionMaskFromInputMask
+
+__all__ = [
+    "BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
+    "GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
+    "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
+    "BertSelfAttention", "BertTransformer", "EmbeddingLookup",
+    "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
+    "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
+]
@ -0,0 +1,434 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Bert for pretraining."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||||
|
from mindspore.train.parallel_utils import ParallelMode
|
||||||
|
from mindspore.communication.management import get_group_size
|
||||||
|
from mindspore import context
|
||||||
|
from .bert_model import BertModel
|
||||||
|
|
||||||
|
GRADIENT_CLIP_TYPE = 1
|
||||||
|
GRADIENT_CLIP_VALUE = 1.0
|
||||||
|
|
||||||
|
_nn_clip_by_norm = nn.ClipByNorm()
|
||||||
|
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||||
|
@clip_grad.register("Number", "Number", "Tensor")
|
||||||
|
def _clip_grad(clip_type, clip_value, grad):
|
||||||
|
"""
|
||||||
|
Clip gradients.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||||
|
clip_value (float): Specifies how much to clip.
|
||||||
|
grad (tuple[Tensor]): Gradients.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
tuple[Tensor], clipped gradients.
|
||||||
|
"""
|
||||||
|
if clip_type != 0 and clip_type != 1:
|
||||||
|
return grad
|
||||||
|
dt = F.dtype(grad)
|
||||||
|
if clip_type == 0:
|
||||||
|
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||||
|
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||||
|
else:
|
||||||
|
new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||||
|
return new_grad
|
||||||
|
|
||||||
|
class GetMaskedLMOutput(nn.Cell):
|
||||||
|
"""
|
||||||
|
Get masked lm output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (BertConfig): The config of BertModel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, masked lm output.
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
super(GetMaskedLMOutput, self).__init__()
|
||||||
|
self.width = config.hidden_size
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.gather = P.GatherV2()
|
||||||
|
|
||||||
|
weight_init = TruncatedNormal(config.initializer_range)
|
||||||
|
self.dense = nn.Dense(self.width,
|
||||||
|
config.hidden_size,
|
||||||
|
weight_init=weight_init,
|
||||||
|
activation=config.hidden_act).to_float(config.compute_type)
|
||||||
|
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
|
||||||
|
self.output_bias = Parameter(
|
||||||
|
initializer(
|
||||||
|
'zero',
|
||||||
|
config.vocab_size),
|
||||||
|
name='output_bias')
|
||||||
|
self.matmul = P.MatMul(transpose_b=True)
|
||||||
|
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||||
|
self.shape_flat_offsets = (-1, 1)
|
||||||
|
self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32))
|
||||||
|
self.last_idx = (-1,)
|
||||||
|
self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width)
|
||||||
|
self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.compute_type = config.compute_type
|
||||||
|
self.dtype = config.dtype
|
||||||
|
|
||||||
|
def construct(self,
|
||||||
|
input_tensor,
|
||||||
|
output_weights,
|
||||||
|
positions):
|
||||||
|
flat_offsets = self.reshape(
|
||||||
|
self.rng * self.seq_length_tensor, self.shape_flat_offsets)
|
||||||
|
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
|
||||||
|
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
|
||||||
|
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
|
||||||
|
input_tensor = self.cast(input_tensor, self.compute_type)
|
||||||
|
output_weights = self.cast(output_weights, self.compute_type)
|
||||||
|
input_tensor = self.dense(input_tensor)
|
||||||
|
input_tensor = self.layernorm(input_tensor)
|
||||||
|
logits = self.matmul(input_tensor, output_weights)
|
||||||
|
logits = self.cast(logits, self.dtype)
|
||||||
|
logits = logits + self.output_bias
|
||||||
|
log_probs = self.log_softmax(logits)
|
||||||
|
return log_probs
|
||||||
|
|
||||||
|
|
||||||
|
class GetNextSentenceOutput(nn.Cell):
|
||||||
|
"""
|
||||||
|
Get next sentence output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (BertConfig): The config of Bert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, next sentence output.
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
super(GetNextSentenceOutput, self).__init__()
|
||||||
|
self.log_softmax = P.LogSoftmax()
|
||||||
|
self.weight_init = TruncatedNormal(config.initializer_range)
|
||||||
|
self.dense = nn.Dense(config.hidden_size, 2,
|
||||||
|
weight_init=self.weight_init, has_bias=True).to_float(config.compute_type)
|
||||||
|
self.dtype = config.dtype
|
||||||
|
self.cast = P.Cast()
|
||||||
|
|
||||||
|
def construct(self, input_tensor):
|
||||||
|
logits = self.dense(input_tensor)
|
||||||
|
logits = self.cast(logits, self.dtype)
|
||||||
|
log_prob = self.log_softmax(logits)
|
||||||
|
return log_prob
|
||||||
|
|
||||||
|
|
||||||
|
class BertPreTraining(nn.Cell):
|
||||||
|
"""
|
||||||
|
Bert pretraining network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (BertConfig): The config of BertModel.
|
||||||
|
is_training (bool): Specifies whether to use the training mode.
|
||||||
|
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, prediction_scores, seq_relationship_score.
|
||||||
|
"""
|
||||||
|
def __init__(self, config, is_training, use_one_hot_embeddings):
|
||||||
|
super(BertPreTraining, self).__init__()
|
||||||
|
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||||
|
self.cls1 = GetMaskedLMOutput(config)
|
||||||
|
self.cls2 = GetNextSentenceOutput(config)
|
||||||
|
|
||||||
|
def construct(self, input_ids, input_mask, token_type_id,
|
||||||
|
masked_lm_positions):
|
||||||
|
sequence_output, pooled_output, embedding_table = \
|
||||||
|
self.bert(input_ids, token_type_id, input_mask)
|
||||||
|
prediction_scores = self.cls1(sequence_output,
|
||||||
|
embedding_table,
|
||||||
|
masked_lm_positions)
|
||||||
|
seq_relationship_score = self.cls2(pooled_output)
|
||||||
|
return prediction_scores, seq_relationship_score
|
||||||
|
|
||||||
|
|
||||||
|
class BertPretrainingLoss(nn.Cell):
    """
    Provide bert pre-training loss.

    Args:
        config (BertConfig): The config of BertModel.

    Returns:
        Tensor, total loss.
    """
    def __init__(self, config):
        super(BertPretrainingLoss, self).__init__()
        self.vocab_size = config.vocab_size
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()

    def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
                  masked_lm_weights, next_sentence_labels):
        """Defines the computation performed."""
        label_ids = self.reshape(masked_lm_ids, self.last_idx)
        label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
        masked_lm_loss = numerator / denominator

        # next_sentence_loss
        labels = self.reshape(next_sentence_labels, self.last_idx)
        one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
        per_example_loss = self.neg(self.reduce_sum(
            one_hot_labels * seq_relationship_score, self.last_idx))
        next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)

        # total_loss
        total_loss = masked_lm_loss + next_sentence_loss

        return total_loss

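# A rough sketch of the arithmetic above (a reading of this class, using its own
# names): the masked-LM term is a weighted average of per-position negative
# log-likelihoods,
#     masked_lm_loss = sum_i(w_i * -log_prob[i, label_i]) / (sum_i(w_i) + 1e-5),
# with the weights w_i taken from masked_lm_weights, and the next-sentence term
# is the mean negative log-likelihood of the two-class label, so
#     total_loss = masked_lm_loss + next_sentence_loss

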
class BertNetworkWithLoss(nn.Cell):
    """
    Provide bert pre-training loss through network.

    Args:
        config (BertConfig): The config of BertModel.
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, config, is_training, use_one_hot_embeddings=False):
        super(BertNetworkWithLoss, self).__init__()
        self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
        self.loss = BertPretrainingLoss(config)
        self.cast = P.Cast()

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        prediction_scores, seq_relationship_score = \
            self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
        total_loss = self.loss(prediction_scores, seq_relationship_score,
                               masked_lm_ids, masked_lm_weights, next_sentence_labels)
        return self.cast(total_loss, mstype.float32)


class BertTrainOneStepCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

        self.cast = P.Cast()
        self.hyper_map = C.HyperMap()

    def set_sens(self, value):
        self.sens = value

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights):
        """Defines the computation performed."""
        weights = self.weights

        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(F.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)

        succ = self.optimizer(grads)
        return F.depend(loss, succ)


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


class BertTrainOneStepWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.add_flags(has_effect=True)

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  next_sentence_labels,
                  masked_lm_positions,
                  masked_lm_ids,
                  masked_lm_weights,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            next_sentence_labels,
                            masked_lm_positions,
                            masked_lm_ids,
                            masked_lm_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 next_sentence_labels,
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)

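# A brief, non-authoritative reading of the loss-scale flow above: the NPU float
# status buffer is cleared before the backward pass and queried afterwards, so a
# reduced flag_sum >= 1 (all-reduced across devices in distributed mode) marks a
# float overflow. On overflow the optimizer step is skipped, and when no explicit
# sens is passed the scale_update_cell (assumed to be something like MindSpore's
# dynamic loss-scale update cell) is given the chance to adjust loss_scale.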
@@ -0,0 +1,949 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert model."""

import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.common.initializer import TruncatedNormal, initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from .fused_layer_norm import FusedLayerNorm


class BertConfig:
    """
    Configuration for `BertModel`.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the dictionary of embeddings. Default: 32000.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        num_hidden_layers (int): Number of hidden layers in the BertTransformer encoder cell. Default: 12.
        num_attention_heads (int): Number of attention heads in the BertTransformer encoder cell. Default: 12.
        intermediate_size (int): Size of intermediate layer in the BertTransformer encoder cell. Default: 3072.
        hidden_act (str): Activation function used in the BertTransformer encoder cell. Default: "gelu".
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this model. Default: 512.
        type_vocab_size (int): Size of token type vocab. Default: 16.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
            dataset. Default: True.
        token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded
            from dataset. Default: True.
        dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length=128,
                 vocab_size=32000,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 use_relative_positions=False,
                 input_mask_from_dataset=True,
                 token_type_ids_from_dataset=True,
                 dtype=mstype.float32,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.input_mask_from_dataset = input_mask_from_dataset
        self.token_type_ids_from_dataset = token_type_ids_from_dataset
        self.use_relative_positions = use_relative_positions
        self.dtype = dtype
        self.compute_type = compute_type
        self.enable_fused_layernorm = enable_fused_layernorm

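# A minimal usage sketch for the class above; the argument values are
# illustrative only and not prescribed by this module:
#
#   bert_base_cfg = BertConfig(batch_size=32,
#                              seq_length=128,
#                              vocab_size=21128,
#                              hidden_size=768,
#                              num_hidden_layers=12,
#                              num_attention_heads=12,
#                              compute_type=mstype.float16)

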
class EmbeddingLookup(nn.Cell):
    """
    An embedding lookup table with a fixed dictionary and size.

    Args:
        vocab_size (int): Size of the dictionary of embeddings.
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
            each embedding vector.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 embedding_shape,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.embedding_table = Parameter(initializer(TruncatedNormal(initializer_range),
                                                     [vocab_size, embedding_size]),
                                         name='embedding_table')
        self.expand = P.ExpandDims()
        self.shape_flat = (-1,)
        self.gather = P.GatherV2()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)

    def construct(self, input_ids):
        extended_ids = self.expand(input_ids, -1)
        flat_ids = self.reshape(extended_ids, self.shape_flat)
        if self.use_one_hot_embeddings:
            one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
            output_for_reshape = self.array_mul(
                one_hot_ids, self.embedding_table)
        else:
            output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
        output = self.reshape(output_for_reshape, self.shape)
        return output, self.embedding_table


class EmbeddingPostprocessor(nn.Cell):
    """
    Postprocessors apply positional and token type embeddings to word embeddings.

    Args:
        embedding_size (int): The size of each embedding vector.
        embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
            each embedding vector.
        use_token_type (bool): Specifies whether to use token type embeddings. Default: False.
        token_type_vocab_size (int): Size of token type vocab. Default: 16.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        max_position_embeddings (int): Maximum length of sequences used in this model. Default: 512.
        dropout_prob (float): The dropout probability. Default: 0.1.
    """
    def __init__(self,
                 embedding_size,
                 embedding_shape,
                 use_relative_positions=False,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 max_position_embeddings=512,
                 dropout_prob=0.1):
        super(EmbeddingPostprocessor, self).__init__()
        self.use_token_type = use_token_type
        self.token_type_vocab_size = token_type_vocab_size
        self.use_one_hot_embeddings = use_one_hot_embeddings
        self.max_position_embeddings = max_position_embeddings
        self.embedding_table = Parameter(initializer(TruncatedNormal(initializer_range),
                                                     [token_type_vocab_size, embedding_size]),
                                         name='embedding_table')

        self.shape_flat = (-1,)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.array_mul = P.MatMul()
        self.reshape = P.Reshape()
        self.shape = tuple(embedding_shape)
        self.layernorm = nn.LayerNorm((embedding_size,))
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.GatherV2()
        self.use_relative_positions = use_relative_positions
        self.slice = P.StridedSlice()
        self.full_position_embeddings = Parameter(initializer(TruncatedNormal(initializer_range),
                                                              [max_position_embeddings, embedding_size]),
                                                  name='full_position_embeddings')

    def construct(self, token_type_ids, word_embeddings):
        output = word_embeddings
        if self.use_token_type:
            flat_ids = self.reshape(token_type_ids, self.shape_flat)
            if self.use_one_hot_embeddings:
                one_hot_ids = self.one_hot(flat_ids,
                                           self.token_type_vocab_size, self.on_value, self.off_value)
                token_type_embeddings = self.array_mul(one_hot_ids,
                                                       self.embedding_table)
            else:
                token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output


class BertOutput(nn.Cell):
    """
    Apply a linear computation to hidden status and a residual computation to input.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
        self.add = P.TensorAdd()
        if compute_type == mstype.float16:
            self.layernorm = FusedLayerNorm((out_channels,),
                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
        else:
            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(output, input_tensor)
        output = self.layernorm(output)
        return output


class RelaPosMatrixGenerator(nn.Cell):
    """
    Generates matrix of relative positions between inputs.

    Args:
        length (int): Length of one dim for the matrix to be generated.
        max_relative_position (int): Max value of relative position.
    """
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1

        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()

    def construct(self):
        range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
        range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
        tile_row_out = self.tile(range_vec_row_out, (self._length,))
        tile_col_out = self.tile(range_vec_col_out, (1, self._length))
        range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
        transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
        distance_mat = self.sub(range_mat_out, transpose_out)

        distance_mat_clipped = C.clip_by_value(distance_mat,
                                               self._min_relative_position,
                                               self._max_relative_position)

        # Shift values to be >=0. Each integer still uniquely identifies a
        # relative position difference.
        final_mat = distance_mat_clipped + self._max_relative_position
        return final_mat


class RelaPosEmbeddingsGenerator(nn.Cell):
    """
    Generates tensor of size [length, length, depth].

    Args:
        length (int): Length of one dim for the matrix to be generated.
        depth (int): Size of each attention head.
        max_relative_position (int): Maximum value of relative position.
        initializer_range (float): Initialization value of TruncatedNormal.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings

        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')

        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = P.Reshape()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = P.Shape()
        self.gather = P.GatherV2()  # index_select
        self.matmul = P.BatchMatMul()

    def construct(self):
        relative_positions_matrix_out = self.relative_positions_matrix()

        # Generate embedding for each relative position of dimension depth.
        if self.use_one_hot_embeddings:
            flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
            one_hot_relative_positions_matrix = self.one_hot(
                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
            embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
            my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
            embeddings = self.reshape(embeddings, my_shape)
        else:
            embeddings = self.gather(self.embeddings_table,
                                     relative_positions_matrix_out, 0)
        return embeddings


class SaturateCast(nn.Cell):
    """
    Performs a safe saturating cast. This operation applies proper clamping before casting to prevent
    the danger that the value will overflow or underflow.

    Args:
        src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
        dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
    """
    def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        np_type = mstype.dtype_to_nptype(dst_type)
        min_type = np.finfo(np_type).min
        max_type = np.finfo(np_type).max

        self.tensor_min_type = Tensor([min_type], dtype=src_type)
        self.tensor_max_type = Tensor([max_type], dtype=src_type)

        self.min_op = P.Minimum()
        self.max_op = P.Maximum()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        out = self.max_op(x, self.tensor_min_type)
        out = self.min_op(out, self.tensor_max_type)
        return self.cast(out, self.dst_type)


class BertAttention(nn.Cell):
    """
    Apply multi-headed attention from "from_tensor" to "to_tensor".

    Args:
        batch_size (int): Batch size of input datasets.
        from_tensor_width (int): Size of last dim of from_tensor.
        to_tensor_width (int): Size of last dim of to_tensor.
        from_seq_length (int): Length of from_tensor sequence.
        to_seq_length (int): Length of to_tensor sequence.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: False.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d tensor. Default: False.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):

        super(BertAttention, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions

        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
        self.reshape = P.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width, units,
                                    activation=query_act, weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width, units,
                                  activation=key_act, weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width, units,
                                    activation=value_act, weight_init=weight).to_float(compute_type)

        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.multiply = P.Mul()
        self.transpose = P.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
        self.batch_num = batch_size * num_attention_heads
        self.matmul = P.BatchMatMul()

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)

        if self.has_attention_mask:
            self.expand_dims = P.ExpandDims()
            self.sub = P.Sub()
            self.add = P.TensorAdd()
            self.cast = P.Cast()
            self.get_dtype = P.DType()
        if do_return_2d_tensor:
            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)

    def construct(self, from_tensor, to_tensor, attention_mask):
        # reshape 2d/3d input tensors to 2d
        from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
        to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
        query_out = self.query_layer(from_tensor_2d)
        key_out = self.key_layer(to_tensor_2d)
        value_out = self.value_layer(to_tensor_2d)

        query_layer = self.reshape(query_out, self.shape_from)
        query_layer = self.transpose(query_layer, self.trans_shape)
        key_layer = self.reshape(key_out, self.shape_to)
        key_layer = self.transpose(key_layer, self.trans_shape)

        attention_scores = self.matmul_trans_b(query_layer, key_layer)

        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # 'relations_keys' = [F|T, F|T, H]
            relations_keys = self._generate_relative_positions_embeddings()
            relations_keys = self.cast_compute_type(relations_keys)
            # query_layer_t is [F, B, N, H]
            query_layer_t = self.transpose(query_layer, self.trans_shape_relative)
            # query_layer_r is [F, B * N, H]
            query_layer_r = self.reshape(query_layer_t,
                                         (self.from_seq_length, self.batch_num, self.size_per_head))
            # key_position_scores is [F, B * N, F|T]
            key_position_scores = self.matmul_trans_b(query_layer_r, relations_keys)
            # key_position_scores_r is [F, B, N, F|T]
            key_position_scores_r = self.reshape(key_position_scores,
                                                 (self.from_seq_length, self.batch_size,
                                                  self.num_attention_heads, self.from_seq_length))
            # key_position_scores_r_t is [B, N, F, F|T]
            key_position_scores_r_t = self.transpose(key_position_scores_r, self.trans_shape_position)
            attention_scores = attention_scores + key_position_scores_r_t

        attention_scores = self.multiply(self.scores_mul, attention_scores)

        if self.has_attention_mask:
            attention_mask = self.expand_dims(attention_mask, 1)
            multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
                                    self.cast(attention_mask, self.get_dtype(attention_scores)))

            adder = self.multiply(multiply_out, self.multiply_data)
            attention_scores = self.add(adder, attention_scores)

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        value_layer = self.reshape(value_out, self.shape_to)
        value_layer = self.transpose(value_layer, self.trans_shape)
        context_layer = self.matmul(attention_probs, value_layer)

        # use_relative_position, supplementary logic
        if self.use_relative_positions:
            # 'relations_values' = [F|T, F|T, H]
            relations_values = self._generate_relative_positions_embeddings()
            relations_values = self.cast_compute_type(relations_values)
            # attention_probs_t is [F, B, N, T]
            attention_probs_t = self.transpose(attention_probs, self.trans_shape_relative)
            # attention_probs_r is [F, B * N, T]
            attention_probs_r = self.reshape(attention_probs_t,
                                             (self.from_seq_length, self.batch_num, self.to_seq_length))
            # value_position_scores is [F, B * N, H]
            value_position_scores = self.matmul(attention_probs_r, relations_values)
            # value_position_scores_r is [F, B, N, H]
            value_position_scores_r = self.reshape(value_position_scores,
                                                   (self.from_seq_length, self.batch_size,
                                                    self.num_attention_heads, self.size_per_head))
            # value_position_scores_r_t is [B, N, F, H]
            value_position_scores_r_t = self.transpose(value_position_scores_r, self.trans_shape_position)
            context_layer = context_layer + value_position_scores_r_t

        context_layer = self.transpose(context_layer, self.trans_shape)
        context_layer = self.reshape(context_layer, self.shape_return)

        return context_layer


class BertSelfAttention(nn.Cell):
    """
    Apply self-attention.

    Args:
        batch_size (int): Batch size of input dataset.
        seq_length (int): Length of input sequence.
        hidden_size (int): Size of the bert encoder layers.
        num_attention_heads (int): Number of attention heads. Default: 12.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))

        self.size_per_head = int(hidden_size / num_attention_heads)

        self.attention = BertAttention(
            batch_size=batch_size,
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)

        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)
        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)

    def construct(self, input_tensor, attention_mask):
        input_tensor = self.reshape(input_tensor, self.shape)
        attention_output = self.attention(input_tensor, input_tensor, attention_mask)
        output = self.output(attention_output, input_tensor)
        return output


class BertEncoderCell(nn.Cell):
    """
    Encoder cells used in BertTransformer.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the bert encoder layers. Default: 768.
        seq_length (int): Length of input sequence. Default: 512.
        num_attention_heads (int): Number of attention heads. Default: 12.
        intermediate_size (int): Size of intermediate layer. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.02.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 hidden_size=768,
                 seq_length=512,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.02,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertEncoderCell, self).__init__()
        self.attention = BertSelfAttention(
            batch_size=batch_size,
            hidden_size=hidden_size,
            seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            hidden_dropout_prob=hidden_dropout_prob,
            use_relative_positions=use_relative_positions,
            compute_type=compute_type,
            enable_fused_layernorm=enable_fused_layernorm)
        self.intermediate = nn.Dense(in_channels=hidden_size,
                                     out_channels=intermediate_size,
                                     activation=hidden_act,
                                     weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.output = BertOutput(in_channels=intermediate_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)

    def construct(self, hidden_states, attention_mask):
        # self-attention
        attention_output = self.attention(hidden_states, attention_mask)
        # feed forward
        intermediate_output = self.intermediate(attention_output)
        # add and normalize
        output = self.output(intermediate_output, attention_output)
        return output


class BertTransformer(nn.Cell):
    """
    Multi-layer bert transformer.

    Args:
        batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the encoder layers.
        seq_length (int): Length of input sequence.
        num_hidden_layers (int): Number of hidden layers in encoder cells.
        num_attention_heads (int): Number of attention heads in encoder cells. Default: 12.
        intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072.
        attention_probs_dropout_prob (float): The dropout probability for BertAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1.
        use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
        hidden_act (str): Activation function used in the encoder cells. Default: "gelu".
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
        return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
    """
    def __init__(self,
                 batch_size,
                 hidden_size,
                 seq_length,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False,
                 enable_fused_layernorm=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders

        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(batch_size=batch_size,
                                    hidden_size=hidden_size,
                                    seq_length=seq_length,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type,
                                    enable_fused_layernorm=enable_fused_layernorm)
            layers.append(layer)

        self.layers = nn.CellList(layers)

        self.reshape = P.Reshape()
        self.shape = (-1, hidden_size)
        self.out_shape = (batch_size, seq_length, hidden_size)

    def construct(self, input_tensor, attention_mask):
        prev_output = self.reshape(input_tensor, self.shape)

        all_encoder_layers = ()
        for layer_module in self.layers:
            layer_output = layer_module(prev_output, attention_mask)
            prev_output = layer_output

            if self.return_all_encoders:
                layer_output = self.reshape(layer_output, self.out_shape)
                all_encoder_layers = all_encoder_layers + (layer_output,)

        if not self.return_all_encoders:
            prev_output = self.reshape(prev_output, self.out_shape)
            all_encoder_layers = all_encoder_layers + (prev_output,)
        return all_encoder_layers


class CreateAttentionMaskFromInputMask(nn.Cell):
    """
    Create attention mask according to input mask.

    Args:
        config (Class): Configuration for BertModel.
    """
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.input_mask = None

        if not self.input_mask_from_dataset:
            self.input_mask = initializer(
                "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()

        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (config.batch_size, 1, config.seq_length)
        self.broadcast_ones = initializer(
            "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
        self.batch_matmul = P.BatchMatMul()

    def construct(self, input_mask):
        if not self.input_mask_from_dataset:
            input_mask = self.input_mask

        input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
        return attention_mask

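# A shape sketch of the construct above: broadcast_ones has shape
# [batch_size, seq_length, 1] and the reshaped input_mask has shape
# [batch_size, 1, seq_length], so their batch matmul yields an attention mask
# of shape [batch_size, seq_length, seq_length] whose column j is zero wherever
# token j is padding.

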
class BertModel(nn.Cell):
    """
    Bidirectional Encoder Representations from Transformers.

    Args:
        config (Class): Configuration for BertModel.
        is_training (bool): True for training mode. False for eval mode.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
    """
    def __init__(self,
                 config,
                 is_training,
                 use_one_hot_embeddings=False):
        super(BertModel, self).__init__()
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
        self.batch_size = config.batch_size
        self.seq_length = config.seq_length
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embedding_size = config.hidden_size
        self.token_type_ids = None

        self.last_idx = self.num_hidden_layers - 1
        output_embedding_shape = [self.batch_size, self.seq_length,
                                  self.embedding_size]

        if not self.token_type_ids_from_dataset:
            self.token_type_ids = initializer(
                "zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()

        self.bert_embedding_lookup = EmbeddingLookup(
            vocab_size=config.vocab_size,
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range)

        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_relative_positions=config.use_relative_positions,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=0.02,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

        self.bert_encoder = BertTransformer(
            batch_size=self.batch_size,
            hidden_size=self.hidden_size,
            seq_length=self.seq_length,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            use_relative_positions=config.use_relative_positions,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type,
            return_all_encoders=True,
            enable_fused_layernorm=config.enable_fused_layernorm)

        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()

        self.squeeze_1 = P.Squeeze(axis=1)
        self.dense = nn.Dense(self.hidden_size, self.hidden_size,
                              activation="tanh",
                              weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

    def construct(self, input_ids, token_type_ids, input_mask):
        # embedding
        if not self.token_type_ids_from_dataset:
            token_type_ids = self.token_type_ids
        word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
        embedding_output = self.bert_embedding_postprocessor(token_type_ids,
                                                             word_embeddings)

        # attention mask [batch_size, seq_length, seq_length]
        attention_mask = self._create_attention_mask_from_input_mask(input_mask)

        # bert encoder
        encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output),
                                           attention_mask)

        sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)

        # pooler
        sequence_slice = self.slice(sequence_output,
                                    (0, 0, 0),
                                    (self.batch_size, 1, self.hidden_size),
                                    (1, 1, 1))
        first_token = self.squeeze_1(sequence_slice)
        pooled_output = self.dense(first_token)
        pooled_output = self.cast(pooled_output, self.dtype)

        return sequence_output, pooled_output, embedding_tables

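# A minimal forward-pass sketch for BertModel; the shapes and values below are
# illustrative assumptions, not fixed by this module:
#
#   cfg = BertConfig(batch_size=2, seq_length=32, vocab_size=21128)
#   net = BertModel(cfg, is_training=False)
#   input_ids = Tensor(np.ones((2, 32)), mstype.int32)
#   token_type_ids = Tensor(np.zeros((2, 32)), mstype.int32)
#   input_mask = Tensor(np.ones((2, 32)), mstype.int32)
#   sequence_output, pooled_output, table = net(input_ids, token_type_ids, input_mask)
#   # sequence_output: [2, 32, 768], pooled_output: [2, 768], table: [21128, 768]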
@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''bert clue evaluation'''

import json
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
import tokenization
from sample_process import label_generation, process_one_example_p
from .evaluation_config import cfg
from .CRF import postprocess

vocab_file = "./vocab.txt"
tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file)

def process(model, text, sequence_length):
    """
    process text.
    """
    data = [text]
    features = []
    res = []
    ids = []
    for i in data:
        feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length)
        features.append(feature)
        input_ids, input_mask, token_type_id = feature
        input_ids = Tensor(np.array(input_ids), mstype.int32)
        input_mask = Tensor(np.array(input_mask), mstype.int32)
        token_type_id = Tensor(np.array(token_type_id), mstype.int32)
        if cfg.use_crf:
            backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
            best_path = postprocess(backpointers, best_tag_id)
            logits = []
            for ele in best_path:
                logits.extend(ele)
            ids = logits
        else:
            logits = model.predict(input_ids, input_mask, token_type_id, Tensor(1))
            ids = logits.asnumpy()
            ids = np.argmax(ids, axis=-1)
            ids = list(ids)
    res = label_generation(text, ids)
    return res

def submit(model, path, sequence_length):
    """
    submit task
    """
    data = []
    for line in open(path):
        if not line.strip():
            continue
        oneline = json.loads(line.strip())
        res = process(model, oneline["text"], sequence_length)
        print("text", oneline["text"])
        print("res:", res)
        data.append(json.dumps({"label": res}, ensure_ascii=False))
    open("ner_predict.json", "w").write("\n".join(data))
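A sketch of how `submit` might be driven once a fine-tuned NER checkpoint exists. The checkpoint path and the input file name are placeholders, and wrapping the network in `Model` mirrors what the evaluation script does; this is illustrative, not part of the file above:

```python
# Hypothetical driver for the CLUE NER submission; paths are placeholders.
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.utils import BertNER
from src.evaluation_config import cfg, bert_net_cfg

net = BertNER(bert_net_cfg, False, num_labels=cfg.num_labels, use_crf=cfg.use_crf)
load_param_into_net(net, load_checkpoint("/your/path/your.ckpt"))
net.set_train(False)

# Reads one JSON object per line and writes ner_predict.json in the working directory.
submit(Model(net), "/your/path/cluener_dev.json", bert_net_cfg.seq_length)
```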
@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in dataset.py, run_pretrain.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig
cfg = edict({
    'bert_network': 'base',
    'loss_scale_value': 65536,
    'scale_factor': 2,
    'scale_window': 1000,
    'optimizer': 'Lamb',
    'AdamWeightDecayDynamicLR': edict({
        'learning_rate': 3e-5,
        'end_learning_rate': 1e-10,
        'power': 5.0,
        'weight_decay': 1e-5,
        'eps': 1e-6,
        'warmup_steps': 10000,
    }),
    'Lamb': edict({
        'start_learning_rate': 3e-5,
        'end_learning_rate': 1e-10,
        'power': 10.0,
        'warmup_steps': 10000,
        'weight_decay': 0.01,
        'eps': 1e-6,
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
})

'''
Including two kinds of network: \
base: Google BERT-base (the base version of the BERT model).
large: BERT-NEZHA (a Chinese pretrained language model developed by Huawei, which introduced an improvement, \
Functional Relative Positional Encoding, as an effective positional encoding scheme).
'''
if cfg.bert_network == 'base':
    bert_net_cfg = BertConfig(
        batch_size=32,
        seq_length=128,
        vocab_size=21136,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=False,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16
    )
if cfg.bert_network == 'nezha':
    bert_net_cfg = BertConfig(
        batch_size=32,
        seq_length=128,
        vocab_size=21136,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16
    )
if cfg.bert_network == 'large':
    bert_net_cfg = BertConfig(
        batch_size=16,
        seq_length=512,
        vocab_size=30528,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=False,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
        enable_fused_layernorm=True
    )
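The `cfg` dict above keys per-optimizer hyper-parameter blocks by the optimizer name, so a run script only has to look up the block matching `cfg.optimizer`. A minimal sketch of that lookup (the module path `src.config` is an assumption, and the real run_pretrain.py builds the learning-rate schedule and optimizer from these values):

```python
# Minimal sketch: pick the hyper-parameter block for the configured optimizer.
from src.config import cfg  # assumed import path for the config file above

opt_name = cfg.optimizer    # 'Lamb', 'AdamWeightDecayDynamicLR' or 'Momentum'
opt_cfg = cfg[opt_name]     # edict with that optimizer's hyper-parameters
print(opt_name, dict(opt_cfg))
```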
@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import log as logger
from .config import bert_net_cfg


def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", enable_data_sink="true",
                        data_sink_steps=1, data_dir=None, schema_dir=None):
    """create train dataset"""
    # apply repeat operations
    repeat_count = epoch_size
    files = os.listdir(data_dir)
    data_files = []
    for file_name in files:
        if "tfrecord" in file_name:
            data_files.append(os.path.join(data_dir, file_name))
    ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                            shard_equal_rows=True)
    ori_dataset_size = ds.get_dataset_size()
    new_size = ori_dataset_size
    if enable_data_sink == "true":
        new_size = data_sink_steps * bert_net_cfg.batch_size
    ds.set_dataset_size(new_size)
    new_repeat_count = int(repeat_count * ori_dataset_size // ds.get_dataset_size())
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
    ds = ds.repeat(new_repeat_count)
    logger.info("data size: {}".format(ds.get_dataset_size()))
    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
    return ds, new_repeat_count
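A sketch of how `create_bert_dataset` might be called from a single-device training script with data sinking enabled; the directory paths and sink-step count are placeholders:

```python
# Hypothetical call, mirroring the arguments run_pretrain.py passes in.
from src.dataset import create_bert_dataset  # assumed import path

ds, repeat_count = create_bert_dataset(epoch_size=1,
                                       device_num=1,
                                       rank=0,
                                       do_shuffle="true",
                                       enable_data_sink="true",
                                       data_sink_steps=100,
                                       data_dir="/your/path/tfrecord_dir",
                                       schema_dir="/your/path/schema.json")
print("batches per repeat:", ds.get_dataset_size(), "repeat count:", repeat_count)
```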
@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
config settings, will be used in finetune.py
"""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig

cfg = edict({
    'task': 'NER',
    'num_labels': 41,
    'data_file': '/your/path/evaluation.tfrecord',
    'schema_file': '/your/path/schema.json',
    'finetune_ckpt': '/your/path/your.ckpt',
    'use_crf': False,
    'clue_benchmark': False,
})

bert_net_cfg = BertConfig(
    batch_size=16 if not cfg.clue_benchmark else 1,
    seq_length=128,
    vocab_size=21128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
@@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
config settings, will be used in finetune.py
"""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from .bert_model import BertConfig

cfg = edict({
    'task': 'NER',
    'num_labels': 41,
    'data_file': '/your/path/train.tfrecord',
    'schema_file': '/your/path/schema.json',
    'epoch_num': 5,
    'ckpt_prefix': 'bert',
    'ckpt_dir': None,
    'pre_training_ckpt': '/your/path/pre_training.ckpt',
    'use_crf': False,
    'optimizer': 'Lamb',
    'AdamWeightDecayDynamicLR': edict({
        'learning_rate': 2e-5,
        'end_learning_rate': 1e-7,
        'power': 1.0,
        'weight_decay': 1e-5,
        'eps': 1e-6,
    }),
    'Lamb': edict({
        'start_learning_rate': 2e-5,
        'end_learning_rate': 1e-7,
        'power': 1.0,
        'decay_filter': lambda x: False,
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
})

bert_net_cfg = BertConfig(
    batch_size=16,
    seq_length=128,
    vocab_size=21128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)

tag_to_index = {
    "O": 0,
    "S_address": 1,
    "B_address": 2,
    "M_address": 3,
    "E_address": 4,
    "S_book": 5,
    "B_book": 6,
    "M_book": 7,
    "E_book": 8,
    "S_company": 9,
    "B_company": 10,
    "M_company": 11,
    "E_company": 12,
    "S_game": 13,
    "B_game": 14,
    "M_game": 15,
    "E_game": 16,
    "S_government": 17,
    "B_government": 18,
    "M_government": 19,
    "E_government": 20,
    "S_movie": 21,
    "B_movie": 22,
    "M_movie": 23,
    "E_movie": 24,
    "S_name": 25,
    "B_name": 26,
    "M_name": 27,
    "E_name": 28,
    "S_organization": 29,
    "B_organization": 30,
    "M_organization": 31,
    "E_organization": 32,
    "S_position": 33,
    "B_position": 34,
    "M_position": 35,
    "E_position": 36,
    "S_scene": 37,
    "B_scene": 38,
    "M_scene": 39,
    "E_scene": 40,
    "<START>": 41,
    "<STOP>": 42
}
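The `tag_to_index` map follows a BMES-style scheme: B/M/E mark the begin, middle and end characters of a multi-character entity, S marks a single-character entity, O marks non-entity characters, and `<START>`/`<STOP>` are the extra CRF transition symbols. A tiny illustration of mapping a tag sequence through the table above:

```python
# Illustrative only: a three-character address entity followed by a non-entity character.
tags = ["B_address", "M_address", "E_address", "O"]
ids = [tag_to_index[t] for t in tags]
print(ids)  # [2, 3, 4, 0]
```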
@@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell

import numpy as np


__all__ = ['FusedLayerNorm']

@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
    print("input_shape: ", x_shape)
    norm_shape = x_shape[begin_norm_axis:]
    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
    print("output_shape: ", output_shape)
    return output_shape

class FusedLayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies
    normalization over a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
    normalization, layer normalization performs exactly the same computation at training and
    testing times. It can be described using the following formula. It is applied across all channels
    and pixels but only one batch size.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
            `begin_norm_axis ... R - 1`.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        use_batch_norm (bool): Whether to use batchnorm to process. Default: False.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape()[1:]
        >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """
    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 use_batch_norm=False):
        super(FusedLayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
                            .format(normalized_shape, type(normalized_shape)))
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.gamma = Parameter(initializer(
            gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(
            beta_init, normalized_shape), name="beta")
        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)

        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
        self.use_batch_norm = use_batch_norm

    def construct(self, input_x):
        if self.use_batch_norm and self.training:
            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
            shape_x = F.shape(input_x)
            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
            input_x = F.reshape(input_x, norm_shape)
            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
            output = F.reshape(output, shape_x)
            y = output * self.gamma + self.beta
        else:
            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s
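A small usage sketch for `FusedLayerNorm`, normalizing the last dimension of a BERT-sized activation. The module path is an assumption, the shapes are illustrative, and whether the BatchNorm-based path is actually taken depends on training mode and the target device:

```python
# Illustrative sketch only; in the model this cell replaces nn.LayerNorm when
# enable_fused_layernorm=True is set in BertConfig.
import numpy as np
from mindspore import Tensor
from src.fused_layer_norm import FusedLayerNorm  # assumed import path

x = Tensor(np.random.randn(8192, 768).astype(np.float32))   # flattened [batch*seq, hidden]
ln = FusedLayerNorm((768,), begin_norm_axis=-1, begin_params_axis=-1,
                    use_batch_norm=True)
y = ln(x)   # same shape and dtype as x
```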
@@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""process txt"""

import re
import json

def process_one_example_p(tokenizer, text, max_seq_len=128):
    """process one testline"""
    textlist = list(text)
    tokens = []
    for _, word in enumerate(textlist):
        token = tokenizer.tokenize(word)
        tokens.extend(token)
    if len(tokens) >= max_seq_len - 1:
        tokens = tokens[0:(max_seq_len - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")
    segment_ids.append(0)
    for _, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
    ntokens.append("[SEP]")
    segment_ids.append(0)
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)
    while len(input_ids) < max_seq_len:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        label_ids.append(0)
        ntokens.append("**NULL**")
    assert len(input_ids) == max_seq_len
    assert len(input_mask) == max_seq_len
    assert len(segment_ids) == max_seq_len

    feature = (input_ids, input_mask, segment_ids)
    return feature

def label_generation(text, probs):
    """generate label"""
    data = [text]
    probs = [probs]
    result = []
    label2id = json.loads(open("./label2id.json").read())
    id2label = [k for k, v in label2id.items()]

    for index, prob in enumerate(probs):
        for v in prob[1:len(data[index]) + 1]:
            result.append(id2label[int(v)])

    labels = {}
    start = None
    index = 0
    for _, t in zip("".join(data), result):
        if re.search("^[BS]", t):
            if start is not None:
                label = result[index - 1][2:]
                if labels.get(label):
                    te_ = text[start:index]
                    labels[label][te_] = [[start, index - 1]]
                else:
                    te_ = text[start:index]
                    labels[label] = {te_: [[start, index - 1]]}
            start = index
        if re.search("^O", t):
            if start is not None:
                label = result[index - 1][2:]
                if labels.get(label):
                    te_ = text[start:index]
                    labels[label][te_] = [[start, index - 1]]
                else:
                    te_ = text[start:index]
                    labels[label] = {te_: [[start, index - 1]]}
            start = None
        index += 1
    if start is not None:
        label = result[start][2:]
        if labels.get(label):
            te_ = text[start:index]
            labels[label][te_] = [[start, index - 1]]
        else:
            te_ = text[start:index]
            labels[label] = {te_: [[start, index - 1]]}
    return labels
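The padding contract of `process_one_example_p` is that every returned list is exactly `max_seq_len` long, with the attention mask marking real tokens (including `[CLS]`/`[SEP]`) versus padding. A toy illustration with a stand-in tokenizer (the real script uses `tokenization.FullTokenizer` with `vocab.txt`):

```python
# Toy illustration of the padding behaviour; FakeTokenizer is hypothetical.
class FakeTokenizer:
    def tokenize(self, word):
        return [word]
    def convert_tokens_to_ids(self, tokens):
        return list(range(len(tokens)))

ids, mask, segs = process_one_example_p(FakeTokenizer(), "北京欢迎你", max_seq_len=8)
assert len(ids) == len(mask) == len(segs) == 8
print(mask)  # [1, 1, 1, 1, 1, 1, 1, 0] -> 5 characters + [CLS] + [SEP], one pad slot
```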
@@ -0,0 +1,263 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''
Functional Cells used in Bert finetune and evaluation.
'''

import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
from .bert_for_pre_training import clip_grad
from .CRF import CRF

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)

class BertFinetuneCell(nn.Cell):
    """
    Specifically defined for finetuning, where only four input tensors are needed.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):

        super(BertFinetuneCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    def construct(self,
                  input_ids,
                  input_mask,
                  token_type_id,
                  label_ids,
                  sens=None):


        weights = self.weights
        init = self.alloc_status()
        loss = self.network(input_ids,
                            input_mask,
                            token_type_id,
                            label_ids)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 label_ids,
                                                 self.cast(scaling_sens,
                                                           mstype.float32))
        clear_before_grad = self.clear_before_grad(init)
        F.control_depend(loss, init)
        self.depend_parameter_use(clear_before_grad, scaling_sens)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        flag = self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        F.control_depend(grads, flag)
        F.control_depend(flag, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond)
        return F.depend(ret, succ)

class BertCLSModel(nn.Cell):
    """
    This class is responsible for classification task evaluation, i.e. XNLI (num_labels=3),
    LCQMC (num_labels=2), Chnsenti (num_labels=2). The returned output represents the final
    logits, as the result of log_softmax is proportional to that of softmax.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
        super(BertCLSModel, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)

    def construct(self, input_ids, input_mask, token_type_id):
        _, pooled_output, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        cls = self.cast(pooled_output, self.dtype)
        cls = self.dropout(cls)
        logits = self.dense_1(cls)
        logits = self.cast(logits, self.dtype)
        log_probs = self.log_softmax(logits)
        return log_probs


class BertNERModel(nn.Cell):
    """
    This class is responsible for sequence labeling task evaluation, i.e. NER (num_labels=11).
    The returned output represents the final logits, as the result of log_softmax is proportional to that of softmax.
    """
    def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNERModel, self).__init__()
        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
        self.cast = P.Cast()
        self.weight_init = TruncatedNormal(config.initializer_range)
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.dtype
        self.num_labels = num_labels
        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                has_bias=True).to_float(config.compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.reshape = P.Reshape()
        self.shape = (-1, config.hidden_size)
        self.use_crf = use_crf
        self.origin_shape = (config.batch_size, config.seq_length, self.num_labels)

    def construct(self, input_ids, input_mask, token_type_id):
        sequence_output, _, _ = \
            self.bert(input_ids, token_type_id, input_mask)
        seq = self.dropout(sequence_output)
        seq = self.reshape(seq, self.shape)
        logits = self.dense_1(seq)
        logits = self.cast(logits, self.dtype)
        if self.use_crf:
            return_value = self.reshape(logits, self.origin_shape)
        else:
            return_value = self.log_softmax(logits)
        return return_value

class CrossEntropyCalculation(nn.Cell):
    """
    Cross Entropy loss
    """
    def __init__(self, is_training=True):
        super(CrossEntropyCalculation, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.reshape = P.Reshape()
        self.last_idx = (-1,)
        self.neg = P.Neg()
        self.cast = P.Cast()
        self.is_training = is_training

    def construct(self, logits, label_ids, num_labels):
        if self.is_training:
            label_ids = self.reshape(label_ids, self.last_idx)
            one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
            per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
            loss = self.reduce_mean(per_example_loss, self.last_idx)
            return_value = self.cast(loss, mstype.float32)
        else:
            return_value = logits * 1.0
        return return_value

class BertCLS(nn.Cell):
    """
    Train interface for classification finetuning task.
    """
    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
        super(BertCLS, self).__init__()
        self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
        self.loss = CrossEntropyCalculation(is_training)
        self.num_labels = num_labels
    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        log_probs = self.bert(input_ids, input_mask, token_type_id)
        loss = self.loss(log_probs, label_ids, self.num_labels)
        return loss


class BertNER(nn.Cell):
    """
    Train interface for sequence labeling finetuning task.
    """
    def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0,
                 use_one_hot_embeddings=False):
        super(BertNER, self).__init__()
        self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings)
        if use_crf:
            if not tag_to_index:
                raise Exception("The dict for tag-index mapping should be provided for CRF.")
            self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training)
        else:
            self.loss = CrossEntropyCalculation(is_training)
        self.num_labels = num_labels
        self.use_crf = use_crf
    def construct(self, input_ids, input_mask, token_type_id, label_ids):
        logits = self.bert(input_ids, input_mask, token_type_id)
        if self.use_crf:
            loss = self.loss(logits, label_ids)
        else:
            loss = self.loss(logits, label_ids, self.num_labels)
        return loss
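A sketch of how these cells compose for NER fine-tuning: the task net pairs the backbone with its loss, and `BertFinetuneCell` wraps it together with an optimizer and a dynamic loss-scale update cell. The optimizer choice and hyper-parameters below are placeholders; the real finetune.py builds them from `finetune_config.cfg`:

```python
# Illustrative wiring only; hyper-parameters are placeholders.
from mindspore import nn
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.finetune_config import bert_net_cfg
from src.utils import BertNER, BertFinetuneCell

netwithloss = BertNER(bert_net_cfg, True, num_labels=41, use_crf=False)
optimizer = nn.Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
update_cell = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2,
                                      scale_window=1000).get_update_cell()
netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
netwithgrads.set_train(True)
```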
@@ -1,52 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test bert cell """
import numpy as np
import pytest

from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertModel
from ....dataset_mock import MindData


def map_bert(record):
    target_data = {'input_ids': None, 'input_mask': None,
                   'segment_ids': None, 'next_sentence_labels': None,
                   'masked_lm_positions': None, 'masked_lm_ids': None,
                   'masked_lm_weights': None}

    sample = dt.parse_single_example(record, target_data)

    return sample['input_ids'], sample['input_mask'], sample['segment_ids'], \
           sample['next_sentence_labels'], sample['masked_lm_positions'], \
           sample['masked_lm_ids'], sample['masked_lm_weights']


def test_bert_model():
    # test for config.hidden_size % config.num_attention_heads != 0
    config_error = BertConfig(32, hidden_size=512, num_attention_heads=10)
    with pytest.raises(ValueError):
        BertModel(config_error, True)


def get_dataset(batch_size=1):
    dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
    dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
                      (batch_size, 20), (batch_size, 20), (batch_size, 20))

    dataset = MindData(size=2, batch_size=batch_size,
                       np_types=dataset_types,
                       output_shapes=dataset_shapes,
                       input_indexs=(0, 1))
    return dataset
@ -1,437 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
""" test bert of graph compile """
|
|
||||||
import functools
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
|
||||||
import mindspore.nn as nn
|
|
||||||
import mindspore.ops.composite as C
|
|
||||||
from mindspore.ops import functional as F
|
|
||||||
from mindspore.common.initializer import TruncatedNormal
|
|
||||||
from mindspore.common.parameter import ParameterTuple
|
|
||||||
from mindspore.common.tensor import Tensor
|
|
||||||
from mindspore.model_zoo.Bert_NEZHA import BertPretrainingLoss, GetNextSentenceOutput
|
|
||||||
from mindspore.model_zoo.Bert_NEZHA.bert_for_pre_training import clip_grad
|
|
||||||
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertConfig, \
|
|
||||||
EmbeddingLookup, EmbeddingPostprocessor, BertOutput, RelaPosMatrixGenerator, \
|
|
||||||
RelaPosEmbeddingsGenerator, SaturateCast, BertAttention, BertSelfAttention, \
|
|
||||||
BertEncoderCell, BertTransformer, CreateAttentionMaskFromInputMask, BertModel
|
|
||||||
from mindspore.nn.layer.basic import Norm
|
|
||||||
from mindspore.nn.optim import AdamWeightDecay, AdamWeightDecayDynamicLR
|
|
||||||
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
|
||||||
from ....mindspore_test_framework.pipeline.forward.compile_forward import \
|
|
||||||
pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
|
||||||
from ....mindspore_test_framework.pipeline.gradient.compile_gradient import \
|
|
||||||
pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
|
||||||
from ....ops_common import convert
|
|
||||||
|
|
||||||
|
|
||||||
def bert_trans():
|
|
||||||
"""bert_trans"""
|
|
||||||
net = BertTransformer(batch_size=1,
|
|
||||||
hidden_size=768,
|
|
||||||
seq_length=128,
|
|
||||||
num_hidden_layers=1,
|
|
||||||
num_attention_heads=12,
|
|
||||||
intermediate_size=768,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
use_relative_positions=False,
|
|
||||||
hidden_act="gelu",
|
|
||||||
compute_type=mstype.float32,
|
|
||||||
return_all_encoders=True)
|
|
||||||
net.set_train()
|
|
||||||
return net
|
|
||||||
|
|
||||||
|
|
||||||
def set_train(net):
|
|
||||||
net.set_train()
|
|
||||||
return net
|
|
||||||
|
|
||||||
|
|
||||||
class NetForAdam(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(NetForAdam, self).__init__()
|
|
||||||
self.dense = nn.Dense(64, 10)
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
x = self.dense(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TrainStepWrapForAdam(nn.Cell):
|
|
||||||
"""TrainStepWrapForAdam definition"""
|
|
||||||
|
|
||||||
def __init__(self, network):
|
|
||||||
super(TrainStepWrapForAdam, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
self.weights = ParameterTuple(network.get_parameters())
|
|
||||||
self.optimizer = AdamWeightDecay(self.weights)
|
|
||||||
self.hyper_map = C.HyperMap()
|
|
||||||
|
|
||||||
def construct(self, x, sens):
|
|
||||||
weights = self.weights
|
|
||||||
grads = C.grad_by_list_with_sens(self.network, weights)(x, sens)
|
|
||||||
grads = self.hyper_map(F.partial(clip_grad, 1, 1.0), grads)
|
|
||||||
return self.optimizer(grads)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainStepWrapForAdamDynamicLr(nn.Cell):
|
|
||||||
"""TrainStepWrapForAdamDynamicLr definition"""
|
|
||||||
|
|
||||||
def __init__(self, network):
|
|
||||||
super(TrainStepWrapForAdamDynamicLr, self).__init__()
|
|
||||||
self.network = network
|
|
||||||
self.weights = ParameterTuple(network.get_parameters())
|
|
||||||
self.optimizer = AdamWeightDecayDynamicLR(self.weights, 10)
|
|
||||||
self.sens = Tensor(np.ones(shape=(1, 10)).astype(np.float32))
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
weights = self.weights
|
|
||||||
grads = C.grad_by_list_with_sens(self.network, weights)(x, self.sens)
|
|
||||||
return self.optimizer(grads)
|
|
||||||
|
|
||||||
|
|
||||||
class TempC2Wrap(nn.Cell):
|
|
||||||
def __init__(self, op, c1=None, c2=None,):
|
|
||||||
super(TempC2Wrap, self).__init__()
|
|
||||||
self.op = op
|
|
||||||
self.c1 = c1
|
|
||||||
self.c2 = c2
|
|
||||||
self.hyper_map = C.HyperMap()
|
|
||||||
|
|
||||||
def construct(self, x1):
|
|
||||||
x = self.hyper_map(F.partial(self.op, self.c1, self.c2), x1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
test_case_cell_ops = [
|
|
||||||
('Norm_keepdims', {
|
|
||||||
'block': Norm(keep_dims=True),
|
|
||||||
'desc_inputs': [[1, 3, 4, 4]],
|
|
||||||
'desc_bprop': [[1]]}),
|
|
||||||
('SaturateCast', {
|
|
||||||
'block': SaturateCast(),
|
|
||||||
'desc_inputs': [[1, 3, 4, 4]],
|
|
||||||
'desc_bprop': [[1, 3, 4, 4]]}),
|
|
||||||
('RelaPosMatrixGenerator_0', {
|
|
||||||
'block': RelaPosMatrixGenerator(length=128, max_relative_position=16),
|
|
||||||
'desc_inputs': [],
|
|
||||||
'desc_bprop': [[128, 128]],
|
|
||||||
'skip': ['backward']}),
|
|
||||||
('RelaPosEmbeddingsGenerator_0', {
|
|
||||||
'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
|
|
||||||
max_relative_position=16,
|
|
||||||
initializer_range=0.2),
|
|
||||||
'desc_inputs': [],
|
|
||||||
'desc_bprop': [[16384, 512]],
|
|
||||||
'skip': ['backward']}),
|
|
||||||
('RelaPosEmbeddingsGenerator_1', {
|
|
||||||
'block': RelaPosEmbeddingsGenerator(length=128, depth=512,
|
|
||||||
max_relative_position=16,
|
|
||||||
initializer_range=0.2,
|
|
||||||
use_one_hot_embeddings=False),
|
|
||||||
'desc_inputs': [],
|
|
||||||
'desc_bprop': [[128, 128, 512]],
|
|
||||||
'skip': ['backward']}),
|
|
||||||
('RelaPosEmbeddingsGenerator_2', {
|
|
||||||
'block': RelaPosEmbeddingsGenerator(length=128, depth=64,
|
|
||||||
max_relative_position=16,
|
|
||||||
initializer_range=0.2,
|
|
||||||
use_one_hot_embeddings=False),
|
|
||||||
'desc_inputs': [],
|
|
||||||
'desc_bprop': [[128, 128, 64]],
|
|
||||||
'skip': ['backward']}),
|
|
||||||
('BertAttention_0', {
|
|
||||||
'block': BertAttention(batch_size=64,
|
|
||||||
from_tensor_width=768,
|
|
||||||
to_tensor_width=768,
|
|
||||||
from_seq_length=128,
|
|
||||||
to_seq_length=128,
|
|
||||||
num_attention_heads=12,
|
|
||||||
size_per_head=64,
|
|
||||||
query_act=None,
|
|
||||||
key_act=None,
|
|
||||||
value_act=None,
|
|
||||||
has_attention_mask=True,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
do_return_2d_tensor=True,
|
|
||||||
use_relative_positions=False,
|
|
||||||
compute_type=mstype.float32),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertAttention_1', {
|
|
||||||
'block': BertAttention(batch_size=64,
|
|
||||||
from_tensor_width=768,
|
|
||||||
to_tensor_width=768,
|
|
||||||
from_seq_length=128,
|
|
||||||
to_seq_length=128,
|
|
||||||
num_attention_heads=12,
|
|
||||||
size_per_head=64,
|
|
||||||
query_act=None,
|
|
||||||
key_act=None,
|
|
||||||
value_act=None,
|
|
||||||
has_attention_mask=True,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
do_return_2d_tensor=True,
|
|
||||||
use_relative_positions=True,
|
|
||||||
compute_type=mstype.float32),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertAttention_2', {
|
|
||||||
'block': BertAttention(batch_size=64,
|
|
||||||
from_tensor_width=768,
|
|
||||||
to_tensor_width=768,
|
|
||||||
from_seq_length=128,
|
|
||||||
to_seq_length=128,
|
|
||||||
num_attention_heads=12,
|
|
||||||
size_per_head=64,
|
|
||||||
query_act=None,
|
|
||||||
key_act=None,
|
|
||||||
value_act=None,
|
|
||||||
has_attention_mask=False,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
do_return_2d_tensor=True,
|
|
||||||
use_relative_positions=True,
|
|
||||||
compute_type=mstype.float32),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertAttention_3', {
|
|
||||||
'block': BertAttention(batch_size=64,
|
|
||||||
from_tensor_width=768,
|
|
||||||
to_tensor_width=768,
|
|
||||||
from_seq_length=128,
|
|
||||||
to_seq_length=128,
|
|
||||||
num_attention_heads=12,
|
|
||||||
size_per_head=64,
|
|
||||||
query_act=None,
|
|
||||||
key_act=None,
|
|
||||||
value_act=None,
|
|
||||||
has_attention_mask=True,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
do_return_2d_tensor=False,
|
|
||||||
use_relative_positions=True,
|
|
||||||
compute_type=mstype.float32),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 768], [64, 128, 128]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertOutput', {
|
|
||||||
'block': BertOutput(in_channels=768,
|
|
||||||
out_channels=768,
|
|
||||||
initializer_range=0.02,
|
|
||||||
dropout_prob=0.1),
|
|
||||||
'desc_inputs': [[8192, 768], [8192, 768]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertSelfAttention_0', {
|
|
||||||
'block': BertSelfAttention(batch_size=64,
|
|
||||||
seq_length=128,
|
|
||||||
hidden_size=768,
|
|
||||||
num_attention_heads=12,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
use_relative_positions=False,
|
|
||||||
compute_type=mstype.float32),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 128]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertEncoderCell', {
|
|
||||||
'block': BertEncoderCell(batch_size=64,
|
|
||||||
hidden_size=768,
|
|
||||||
seq_length=128,
|
|
||||||
num_attention_heads=12,
|
|
||||||
intermediate_size=768,
|
|
||||||
attention_probs_dropout_prob=0.02,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
use_relative_positions=False,
|
|
||||||
hidden_act="gelu",
|
|
||||||
compute_type=mstype.float32),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 128]],
|
|
||||||
'desc_bprop': [[8192, 768]]}),
|
|
||||||
('BertTransformer_0', {
|
|
||||||
'block': BertTransformer(batch_size=1,
|
|
||||||
hidden_size=768,
|
|
||||||
seq_length=128,
|
|
||||||
num_hidden_layers=1,
|
|
||||||
num_attention_heads=12,
|
|
||||||
intermediate_size=768,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
use_relative_positions=False,
|
|
||||||
hidden_act="gelu",
|
|
||||||
compute_type=mstype.float32,
|
|
||||||
return_all_encoders=True),
|
|
||||||
'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
|
|
||||||
('BertTransformer_1', {
|
|
||||||
'block': BertTransformer(batch_size=64,
|
|
||||||
hidden_size=768,
|
|
||||||
seq_length=128,
|
|
||||||
num_hidden_layers=2,
|
|
||||||
num_attention_heads=12,
|
|
||||||
intermediate_size=768,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
use_relative_positions=True,
|
|
||||||
hidden_act="gelu",
|
|
||||||
compute_type=mstype.float32,
|
|
||||||
return_all_encoders=False),
|
|
||||||
'desc_inputs': [[64, 128, 768], [64, 128, 128]]}),
|
|
||||||
('EmbeddingLookup', {
|
|
||||||
'block': EmbeddingLookup(vocab_size=32000,
|
|
||||||
embedding_size=768,
|
|
||||||
embedding_shape=[1, 128, 768],
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02),
|
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32))],
|
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768]],
|
|
||||||
'num_output': 2}),
|
|
||||||
('EmbeddingPostprocessor', {
|
|
||||||
'block': EmbeddingPostprocessor(embedding_size=768,
|
|
||||||
embedding_shape=[1, 128, 768],
|
|
||||||
use_token_type=True,
|
|
||||||
token_type_vocab_size=16,
|
|
||||||
use_one_hot_embeddings=False,
|
|
||||||
initializer_range=0.02,
|
|
||||||
max_position_embeddings=512,
|
|
||||||
dropout_prob=0.1),
|
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)), [1, 128, 768]],
|
|
||||||
'desc_bprop': [[1, 128, 768]]}),
|
|
||||||
('CreateAttentionMaskFromInputMask', {
|
|
||||||
'block': CreateAttentionMaskFromInputMask(config=BertConfig(batch_size=1)),
|
|
||||||
'desc_inputs': [[128]],
|
|
||||||
'desc_bprop': [[1, 128, 128]]}),
|
|
||||||
('BertOutput_0', {
|
|
||||||
'block': BertOutput(in_channels=768,
|
|
||||||
out_channels=768,
|
|
||||||
initializer_range=0.02,
|
|
||||||
dropout_prob=0.1),
|
|
||||||
'desc_inputs': [[1, 768], [1, 768]],
|
|
||||||
'desc_bprop': [[1, 768]]}),
|
|
||||||
('BertTransformer_2', {
|
|
||||||
'block': bert_trans(),
|
|
||||||
'desc_inputs': [[1, 128, 768], [1, 128, 128]]}),
|
|
||||||
|
|
||||||
('BertModel', {
|
|
||||||
'block': BertModel(config=BertConfig(batch_size=1,
|
|
||||||
num_hidden_layers=1,
|
|
||||||
intermediate_size=768,
|
|
||||||
token_type_ids_from_dataset=False),
|
|
||||||
is_training=True),
|
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
|
||||||
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
|
||||||
'num_output': 3}),
|
|
||||||
|
|
||||||
('BertModel_1', {
|
|
||||||
'block': BertModel(config=BertConfig(batch_size=1,
|
|
||||||
num_hidden_layers=1,
|
|
||||||
intermediate_size=768,
|
|
||||||
token_type_ids_from_dataset=False),
|
|
||||||
is_training=False),
|
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
|
||||||
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
|
||||||
'num_output': 3}),
|
|
||||||
|
|
||||||
('BertModel_2', {
|
|
||||||
'block': BertModel(config=BertConfig(batch_size=1,
|
|
||||||
num_hidden_layers=1,
|
|
||||||
intermediate_size=768,
|
|
||||||
token_type_ids_from_dataset=False,
|
|
||||||
input_mask_from_dataset=False),
|
|
||||||
is_training=True),
|
|
||||||
'desc_inputs': [Tensor(np.random.rand(128).astype(np.int32)),
|
|
||||||
Tensor(np.random.rand(128).astype(np.int32)), [128]],
|
|
||||||
'desc_bprop': [[1, 128, 768], [1, 128, 768], [1, 128, 768]],
|
|
||||||
'num_output': 3}),
|
|
||||||
|
|
||||||
    ('BertPretrainingLoss', {
        'block': BertPretrainingLoss(config=BertConfig(batch_size=1)),
        'desc_inputs': [[32000], [20, 2], Tensor(np.array([1]).astype(np.int32)),
                        [20], Tensor(np.array([20]).astype(np.int32))],
        'desc_bprop': [[1]],
        'num_output': 1}),
    ('Dense_1', {
        'block': nn.Dense(in_channels=768,
                          out_channels=3072,
                          activation='gelu',
                          weight_init=TruncatedNormal(0.02)),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('Dense_2', {
        'block': set_train(nn.Dense(in_channels=768,
                                    out_channels=3072,
                                    activation='gelu',
                                    weight_init=TruncatedNormal(0.02))),
        'desc_inputs': [[3, 768]],
        'desc_bprop': [[3, 3072]]}),
    ('GetNextSentenceOutput', {
        'block': GetNextSentenceOutput(BertConfig(batch_size=1)),
        'desc_inputs': [[128, 768]],
        'desc_bprop': [[128, 2]]}),
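    # The optimizer cases below wrap a small net in a one-step training cell via
    # set_train(TrainStepWrap*); they only need to compile and run forward, so
    # 'skip': ['backward'] keeps them out of the backward pipelines built below.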
    ('Adam_1', {
        'block': set_train(TrainStepWrapForAdam(NetForAdam())),
        'desc_inputs': [[1, 64], [1, 10]],
        'skip': ['backward']}),
    ('Adam_2', {
        'block': set_train(TrainStepWrapForAdam(GetNextSentenceOutput(BertConfig(batch_size=1)))),
        'desc_inputs': [[128, 768], [128, 2]],
        'skip': ['backward']}),
    ('AdamWeightDecayDynamicLR', {
        'block': set_train(TrainStepWrapForAdamDynamicLr(NetForAdam())),
        'desc_inputs': [[1, 64]],
        'skip': ['backward']}),
    ('ClipGradients', {
        'block': TempC2Wrap(clip_grad, 1, 1.0),
        'desc_inputs': [tuple(convert(shp) for shp in [[1], [1], [1]])],
        'skip': ['backward', 'exec']}),
]
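# Each entry above is a (name, config) pair: 'block' is the cell under test,
# 'desc_inputs' gives input shapes or concrete Tensors, 'desc_bprop' gives the
# sense shapes fed to the backward pass, and the optional 'num_output' / 'skip'
# keys declare the output arity and which pipelines skip the case.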

test_case = functools.reduce(lambda x, y: x + y, [test_case_cell_ops])
# use -k to select a certain test case
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm

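# Filter the master list into per-pipeline case lists: a case is dropped from a
# pipeline when the matching keyword appears in its 'skip' list.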
test_exec_case = filter(lambda x: 'skip' not in x[1] or
                        'exec' not in x[1]['skip'], test_case)
test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
                                 'backward' not in x[1]['skip'] and 'backward_exec'
                                 not in x[1]['skip'], test_case)
test_check_gradient_case = filter(lambda x: 'skip' not in x[1] or
                                  'backward' not in x[1]['skip'] and 'backward_exec'
                                  not in x[1]['skip'], test_case)


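# A sketch of how these are consumed (assumption based on the mindspore_test
# framework): each decorated function returns a case list, and the named pipeline
# compiles the forward (or grad) graph for every case in that list.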
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
    return test_exec_case


@mindspore_test(pipeline_for_compile_grad_ge_graph_for_case_by_case_config)
def test_backward_exec():
    return test_backward_exec_case

@ -1,66 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_embedding """
import numpy as np

from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import EmbeddingLookup, EmbeddingPostprocessor
from ..ut_filter import non_graph_engine


@non_graph_engine
def test_check_embedding_lookup_1():
    m = EmbeddingLookup(vocab_size=32000,
                        embedding_size=768,
                        embedding_shape=[1, 128, 768],
                        use_one_hot_embeddings=False)
    m(Tensor(np.ones([128]), mstype.int32))


@non_graph_engine
def test_check_embedding_lookup_2():
    m = EmbeddingLookup(vocab_size=32000,
                        embedding_size=768,
                        embedding_shape=[1, 128, 768],
                        use_one_hot_embeddings=True)
    m(Tensor(np.ones([128]), mstype.int32))


@non_graph_engine
def test_check_embedding_lookup_3():
    m = EmbeddingLookup(vocab_size=32000,
                        embedding_size=768,
                        embedding_shape=[1, 128, 768],
                        use_one_hot_embeddings=True,
                        initializer_range=0.01)
    m(Tensor(np.ones([128]), mstype.int32))


@non_graph_engine
def test_embedding_post_1():
    m = EmbeddingPostprocessor(embedding_size=768,
                               embedding_shape=[1, 128, 768],
                               use_token_type=True)
    m(Tensor(np.ones([128]), mstype.int32), Tensor(np.ones([1, 128, 768]), mstype.float32))


@non_graph_engine
def test_embedding_post_2():
    m = EmbeddingPostprocessor(embedding_size=768,
                               embedding_shape=[1, 128, 768],
                               use_token_type=True,
                               initializer_range=0.3)
    m(Tensor(np.ones([128]), mstype.int32), Tensor(np.ones([1, 128, 768]), mstype.float32))