pr to master #8

Open
m7grui4p8 wants to merge 201 commits from p69201753/mindspore:cpu-kernel-reuse-1 into master
8 changed files with 82 additions and 92 deletions
Showing only changes of commit 7af8e0a9cf - Show all commits

View File

@ -59,11 +59,11 @@ def create_network(name, *args, **kwargs):
if name == 'bert_base':
if "seq_length" in kwargs:
bert_net_cfg_base.seq_length = kwargs["seq_length"]
is_training = kwargs.get("is_training", default=False)
is_training = kwargs.get("is_training", False)
return BertModel(bert_net_cfg_base, is_training, *args)
if name == 'bert_nezha':
if "seq_length" in kwargs:
bert_net_cfg_nezha.seq_length = kwargs["seq_length"]
is_training = kwargs.get("is_training", default=False)
is_training = kwargs.get("is_training", False)
return BertModel(bert_net_cfg_nezha, is_training, *args)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -207,6 +207,7 @@ options:
`gd_config.py` and `td_config.py` contain parameters of BERT model and options for optimizer and lossscale.
### Options:
```
batch_size batch size of input dataset: N, default is 16
Parameters for lossscale:
loss_scale_value initial value of loss scale: N, default is 2^8
scale_factor factor used to update loss scale: N, default is 2
@ -223,7 +224,6 @@ Parameters for optimizer:
### Parameters:
```
Parameters for bert network:
batch_size batch size of input dataset: N, default is 16
seq_length length of input sequence: N, default is 128
vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 30522
hidden_size size of bert encoder layers: N
@ -239,8 +239,6 @@ Parameters for bert network:
type_vocab_size size of token type vocab: N, default is 2
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
use_relative_positions use relative positions or not: True | False, default is False
input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True
token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
```

View File

@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
Bert hub interface for bert base and bert nezha
'''
from src.tinybert_model import TinyBertModel
from src.tinybert_model import BertConfig
import mindspore.common.dtype as mstype
tinybert_student_net_cfg = BertConfig(
seq_length=128,
vocab_size=30522,
hidden_size=384,
num_hidden_layers=4,
num_attention_heads=12,
intermediate_size=1536,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
dtype=mstype.float32,
compute_type=mstype.float16
)
def create_network(name, *args, **kwargs):
'''
Create tinybert network.
'''
if name == "tinybert":
if "seq_length" in kwargs:
tinybert_student_net_cfg.seq_length = kwargs["seq_length"]
is_training = kwargs.get("is_training", False)
return TinyBertModel(tinybert_student_net_cfg, is_training, *args)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@ -110,7 +110,7 @@ def run_general_distill():
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
dataset = create_tinybert_dataset('gd', common_cfg.batch_size, device_num, rank,
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
data_type=dataset_type)
dataset_size = dataset.get_dataset_size()

View File

@ -29,7 +29,7 @@ from mindspore import log as logger
from src.dataset import create_tinybert_dataset, DataType
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.assessment_method import Accuracy
from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
from src.tinybert_model import BertModelCLS
@ -130,7 +130,7 @@ def run_predistill():
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
dataset = create_tinybert_dataset('td', cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir,
data_type=dataset_type)
@ -194,7 +194,7 @@ def run_task_distill(ckpt_file):
rank = 0
device_num = 1
train_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
train_dataset = create_tinybert_dataset('td', cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir)
@ -224,7 +224,7 @@ def run_task_distill(ckpt_file):
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.eval_data_dir, args_opt.schema_dir)
print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
@ -269,7 +269,7 @@ def do_eval_standalone():
load_param_into_net(eval_model, new_param_dict)
eval_model.set_train(False)
eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
device_num=1, rank=0, do_shuffle="false",
data_dir=args_opt.eval_data_dir,
schema_dir=args_opt.schema_dir)

View File

@ -20,6 +20,7 @@ from easydict import EasyDict as edict
from .tinybert_model import BertConfig
common_cfg = edict({
'batch_size': 32,
'loss_scale_value': 2 ** 16,
'scale_factor': 2,
'scale_window': 1000,
@ -38,7 +39,6 @@ teacher network: The BERT-base network.
student network: The network which is inherited from teacher network.
'''
bert_teacher_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=30522,
hidden_size=768,
@ -52,13 +52,10 @@ bert_teacher_net_cfg = BertConfig(
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)
bert_student_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=30522,
hidden_size=384,
@ -72,8 +69,6 @@ bert_student_net_cfg = BertConfig(
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)

View File

@ -20,6 +20,7 @@ from easydict import EasyDict as edict
from .tinybert_model import BertConfig
phase1_cfg = edict({
'batch_size': 32,
'loss_scale_value': 2 ** 8,
'scale_factor': 2,
'scale_window': 50,
@ -36,6 +37,7 @@ phase1_cfg = edict({
})
phase2_cfg = edict({
'batch_size': 32,
'loss_scale_value': 2 ** 16,
'scale_factor': 2,
'scale_window': 50,
@ -51,13 +53,16 @@ phase2_cfg = edict({
}),
})
eval_cfg = edict({
'batch_size': 32,
})
'''
Including two kinds of network: \
teacher network: The BERT-base network with finetune.
student network: The model which is producted by GD phase.
'''
td_teacher_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=30522,
hidden_size=768,
@ -71,13 +76,10 @@ td_teacher_net_cfg = BertConfig(
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)
td_student_net_cfg = BertConfig(
batch_size=32,
seq_length=128,
vocab_size=30522,
hidden_size=384,
@ -91,8 +93,6 @@ td_student_net_cfg = BertConfig(
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16
)

View File

@ -32,7 +32,6 @@ class BertConfig:
Configuration for `BertModel`.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence. Default: 128.
vocab_size (int): The shape of each embedding vector. Default: 32000.
hidden_size (int): Size of the bert encoder layers. Default: 768.
@ -52,15 +51,10 @@ class BertConfig:
type_vocab_size (int): Size of token type vocab. Default: 16.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset. Default: True.
token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded
from dataset. Default: True.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length=128,
vocab_size=32000,
hidden_size=768,
@ -74,11 +68,8 @@ class BertConfig:
type_vocab_size=16,
initializer_range=0.02,
use_relative_positions=False,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
@ -91,8 +82,6 @@ class BertConfig:
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.input_mask_from_dataset = input_mask_from_dataset
self.token_type_ids_from_dataset = token_type_ids_from_dataset
self.use_relative_positions = use_relative_positions
self.dtype = dtype
self.compute_type = compute_type
@ -390,7 +379,6 @@ class BertAttention(nn.Cell):
Apply multi-headed attention from "from_tensor" to "to_tensor".
Args:
batch_size (int): Batch size of input datasets.
from_tensor_width (int): Size of last dim of from_tensor.
to_tensor_width (int): Size of last dim of to_tensor.
from_seq_length (int): Length of from_tensor sequence.
@ -411,7 +399,6 @@ class BertAttention(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
from_tensor_width,
to_tensor_width,
from_seq_length,
@ -429,7 +416,6 @@ class BertAttention(nn.Cell):
use_relative_positions=False,
compute_type=mstype.float32):
super(BertAttention, self).__init__()
self.batch_size = batch_size
self.from_seq_length = from_seq_length
self.to_seq_length = to_seq_length
self.num_attention_heads = num_attention_heads
@ -454,9 +440,8 @@ class BertAttention(nn.Cell):
units,
activation=value_act,
weight_init=weight).to_float(compute_type)
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
self.shape_to = (
batch_size, to_seq_length, num_attention_heads, size_per_head)
self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.multiply = P.Mul()
self.transpose = P.Transpose()
@ -464,7 +449,6 @@ class BertAttention(nn.Cell):
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()
self.softmax = nn.Softmax()
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
@ -475,9 +459,9 @@ class BertAttention(nn.Cell):
self.cast = P.Cast()
self.get_dtype = P.DType()
if do_return_2d_tensor:
self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
self.shape_return = (-1, num_attention_heads * size_per_head)
else:
self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
self.cast_compute_type = SaturateCast(dst_type=compute_type)
if self.use_relative_positions:
self._generate_relative_positions_embeddings = \
@ -510,7 +494,7 @@ class BertAttention(nn.Cell):
# query_layer_r is [F, B * N, H]
query_layer_r = self.reshape(query_layer_t,
(self.from_seq_length,
self.batch_num,
-1,
self.size_per_head))
# key_position_scores is [F, B * N, F|T]
key_position_scores = self.matmul_trans_b(query_layer_r,
@ -518,7 +502,7 @@ class BertAttention(nn.Cell):
# key_position_scores_r is [F, B, N, F|T]
key_position_scores_r = self.reshape(key_position_scores,
(self.from_seq_length,
self.batch_size,
-1,
self.num_attention_heads,
self.from_seq_length))
# key_position_scores_r_t is [B, N, F, F|T]
@ -548,7 +532,7 @@ class BertAttention(nn.Cell):
attention_probs_r = self.reshape(
attention_probs_t,
(self.from_seq_length,
self.batch_num,
-1,
self.to_seq_length))
# value_position_scores is [F, B * N, H]
value_position_scores = self.matmul(attention_probs_r,
@ -556,7 +540,7 @@ class BertAttention(nn.Cell):
# value_position_scores_r is [F, B, N, H]
value_position_scores_r = self.reshape(value_position_scores,
(self.from_seq_length,
self.batch_size,
-1,
self.num_attention_heads,
self.size_per_head))
# value_position_scores_r_t is [B, N, F, H]
@ -572,7 +556,6 @@ class BertSelfAttention(nn.Cell):
Apply self-attention.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence.
hidden_size (int): Size of the bert encoder layers.
num_attention_heads (int): Number of attention heads. Default: 12.
@ -585,7 +568,6 @@ class BertSelfAttention(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length,
hidden_size,
num_attention_heads=12,
@ -601,7 +583,6 @@ class BertSelfAttention(nn.Cell):
"of attention heads (%d)" % (hidden_size, num_attention_heads))
self.size_per_head = int(hidden_size / num_attention_heads)
self.attention = BertAttention(
batch_size=batch_size,
from_tensor_width=hidden_size,
to_tensor_width=hidden_size,
from_seq_length=seq_length,
@ -636,7 +617,6 @@ class BertEncoderCell(nn.Cell):
Encoder cells used in BertTransformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the bert encoder layers. Default: 768.
seq_length (int): Length of input sequence. Default: 512.
num_attention_heads (int): Number of attention heads. Default: 12.
@ -651,7 +631,6 @@ class BertEncoderCell(nn.Cell):
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size=768,
seq_length=512,
num_attention_heads=12,
@ -665,7 +644,6 @@ class BertEncoderCell(nn.Cell):
compute_type=mstype.float32):
super(BertEncoderCell, self).__init__()
self.attention = BertSelfAttention(
batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
@ -700,7 +678,6 @@ class BertTransformer(nn.Cell):
Multi-layer bert transformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the encoder layers.
seq_length (int): Length of input sequence.
num_hidden_layers (int): Number of hidden layers in encoder cells.
@ -717,7 +694,6 @@ class BertTransformer(nn.Cell):
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
"""
def __init__(self,
batch_size,
hidden_size,
seq_length,
num_hidden_layers,
@ -735,8 +711,7 @@ class BertTransformer(nn.Cell):
self.return_all_encoders = return_all_encoders
layers = []
for _ in range(num_hidden_layers):
layer = BertEncoderCell(batch_size=batch_size,
hidden_size=hidden_size,
layer = BertEncoderCell(hidden_size=hidden_size,
seq_length=seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
@ -751,7 +726,7 @@ class BertTransformer(nn.Cell):
self.layers = nn.CellList(layers)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
self.out_shape = (batch_size, seq_length, hidden_size)
self.out_shape = (-1, seq_length, hidden_size)
def construct(self, input_tensor, attention_mask):
"""bert transformer"""
prev_output = self.reshape(input_tensor, self.shape)
@ -782,22 +757,13 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
"""
def __init__(self, config):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.input_mask_from_dataset = config.input_mask_from_dataset
self.input_mask = None
if not self.input_mask_from_dataset:
self.input_mask = initializer(
"ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = (config.batch_size, 1, config.seq_length)
self.broadcast_ones = initializer(
"ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
self.batch_matmul = P.BatchMatMul()
self.shape = (-1, 1, config.seq_length)
def construct(self, input_mask):
if not self.input_mask_from_dataset:
input_mask = self.input_mask
input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
return attention_mask
class BertModel(nn.Cell):
@ -818,20 +784,14 @@ class BertModel(nn.Cell):
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.input_mask_from_dataset = config.input_mask_from_dataset
self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
self.batch_size = config.batch_size
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.token_type_ids = None
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [self.batch_size, self.seq_length,
output_embedding_shape = [-1, self.seq_length,
self.embedding_size]
if not self.token_type_ids_from_dataset:
self.token_type_ids = initializer(
"zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
self.bert_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
@ -849,7 +809,6 @@ class BertModel(nn.Cell):
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.bert_encoder = BertTransformer(
batch_size=self.batch_size,
hidden_size=self.hidden_size,
seq_length=self.seq_length,
num_attention_heads=config.num_attention_heads,
@ -876,8 +835,6 @@ class BertModel(nn.Cell):
def construct(self, input_ids, token_type_ids, input_mask):
"""bert model"""
# embedding
if not self.token_type_ids_from_dataset:
token_type_ids = self.token_type_ids
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
@ -889,7 +846,7 @@ class BertModel(nn.Cell):
# pooler
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(self.batch_size, 1, self.hidden_size),
(-1, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
@ -921,20 +878,14 @@ class TinyBertModel(nn.Cell):
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.input_mask_from_dataset = config.input_mask_from_dataset
self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
self.batch_size = config.batch_size
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.token_type_ids = None
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [self.batch_size, self.seq_length,
output_embedding_shape = [-1, self.seq_length,
self.embedding_size]
if not self.token_type_ids_from_dataset:
self.token_type_ids = initializer(
"zeros", [self.batch_size, self.seq_length], mstype.int32).to_tensor()
self.tinybert_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
@ -952,7 +903,6 @@ class TinyBertModel(nn.Cell):
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.tinybert_encoder = BertTransformer(
batch_size=self.batch_size,
hidden_size=self.hidden_size,
seq_length=self.seq_length,
num_attention_heads=config.num_attention_heads,
@ -979,8 +929,6 @@ class TinyBertModel(nn.Cell):
def construct(self, input_ids, token_type_ids, input_mask):
"""tiny bert model"""
# embedding
if not self.token_type_ids_from_dataset:
token_type_ids = self.token_type_ids
word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids)
embedding_output = self.tinybert_embedding_postprocessor(token_type_ids,
word_embeddings)
@ -993,7 +941,7 @@ class TinyBertModel(nn.Cell):
# pooler
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(self.batch_size, 1, self.hidden_size),
(-1, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)