Bugfix: BERT perf test

This commit is contained in:
yoonlee666 2020-09-18 17:06:35 +08:00
parent faa0a6ad45
commit f12291a38c
1 changed file with 6 additions and 12 deletions

View File

@@ -54,13 +54,12 @@ def load_test_data(batch_size=1):
return ret return ret
def get_config(version='base', batch_size=1): def get_config(version='base'):
""" """
get_config definition get_config definition
""" """
if version == 'base': if version == 'base':
return BertConfig( return BertConfig(
batch_size=batch_size,
seq_length=128, seq_length=128,
vocab_size=21128, vocab_size=21128,
hidden_size=768, hidden_size=768,
@@ -74,13 +73,10 @@ def get_config(version='base', batch_size=1):
type_vocab_size=2, type_vocab_size=2,
initializer_range=0.02, initializer_range=0.02,
use_relative_positions=True, use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
if version == 'large': if version == 'large':
return BertConfig( return BertConfig(
batch_size=batch_size,
seq_length=128, seq_length=128,
vocab_size=21128, vocab_size=21128,
hidden_size=1024, hidden_size=1024,
@@ -94,11 +90,9 @@ def get_config(version='base', batch_size=1):
type_vocab_size=2, type_vocab_size=2,
initializer_range=0.02, initializer_range=0.02,
use_relative_positions=True, use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32, dtype=mstype.float32,
compute_type=mstype.float32) compute_type=mstype.float32)
return BertConfig(batch_size=batch_size) return BertConfig()
class BertLearningRate(lr_schedules.LearningRateSchedule): class BertLearningRate(lr_schedules.LearningRateSchedule):
@@ -143,7 +137,7 @@ def test_bert_train():
batch_size = int(os.getenv('BATCH_SIZE', '1')) batch_size = int(os.getenv('BATCH_SIZE', '1'))
inputs = load_test_data(batch_size) inputs = load_test_data(batch_size)
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
lr = BertLearningRate(10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
@@ -168,7 +162,7 @@ def test_bert_withlossscale_train():
scaling_sens = Tensor(np.ones([1]).astype(np.float32)) scaling_sens = Tensor(np.ones([1]).astype(np.float32))
inputs = load_test_data(batch_size) + (scaling_sens,) inputs = load_test_data(batch_size) + (scaling_sens,)
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
lr = BertLearningRate(10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
@@ -195,7 +189,7 @@ def bert_withlossscale_manager_train():
batch_size = int(os.getenv('BATCH_SIZE', '1')) batch_size = int(os.getenv('BATCH_SIZE', '1'))
inputs = load_test_data(batch_size) inputs = load_test_data(batch_size)
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
lr = BertLearningRate(10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
@@ -223,7 +217,7 @@ def bert_withlossscale_manager_train_feed():
scaling_sens = Tensor(np.ones([1]).astype(np.float32)) scaling_sens = Tensor(np.ones([1]).astype(np.float32))
inputs = load_test_data(batch_size) + (scaling_sens,) inputs = load_test_data(batch_size) + (scaling_sens,)
config = get_config(version=version, batch_size=batch_size) config = get_config(version=version)
netwithloss = BertNetworkWithLoss(config, True) netwithloss = BertNetworkWithLoss(config, True)
lr = BertLearningRate(10) lr = BertLearningRate(10)
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr) optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)