forked from mindspore-Ecosystem/mindspore
bugfix bert perf test
This commit is contained in:
parent
faa0a6ad45
commit
f12291a38c
|
@ -54,13 +54,12 @@ def load_test_data(batch_size=1):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def get_config(version='base', batch_size=1):
|
def get_config(version='base'):
|
||||||
"""
|
"""
|
||||||
get_config definition
|
get_config definition
|
||||||
"""
|
"""
|
||||||
if version == 'base':
|
if version == 'base':
|
||||||
return BertConfig(
|
return BertConfig(
|
||||||
batch_size=batch_size,
|
|
||||||
seq_length=128,
|
seq_length=128,
|
||||||
vocab_size=21128,
|
vocab_size=21128,
|
||||||
hidden_size=768,
|
hidden_size=768,
|
||||||
|
@ -74,13 +73,10 @@ def get_config(version='base', batch_size=1):
|
||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
use_relative_positions=True,
|
use_relative_positions=True,
|
||||||
input_mask_from_dataset=True,
|
|
||||||
token_type_ids_from_dataset=True,
|
|
||||||
dtype=mstype.float32,
|
dtype=mstype.float32,
|
||||||
compute_type=mstype.float32)
|
compute_type=mstype.float32)
|
||||||
if version == 'large':
|
if version == 'large':
|
||||||
return BertConfig(
|
return BertConfig(
|
||||||
batch_size=batch_size,
|
|
||||||
seq_length=128,
|
seq_length=128,
|
||||||
vocab_size=21128,
|
vocab_size=21128,
|
||||||
hidden_size=1024,
|
hidden_size=1024,
|
||||||
|
@ -94,11 +90,9 @@ def get_config(version='base', batch_size=1):
|
||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
use_relative_positions=True,
|
use_relative_positions=True,
|
||||||
input_mask_from_dataset=True,
|
|
||||||
token_type_ids_from_dataset=True,
|
|
||||||
dtype=mstype.float32,
|
dtype=mstype.float32,
|
||||||
compute_type=mstype.float32)
|
compute_type=mstype.float32)
|
||||||
return BertConfig(batch_size=batch_size)
|
return BertConfig()
|
||||||
|
|
||||||
|
|
||||||
class BertLearningRate(lr_schedules.LearningRateSchedule):
|
class BertLearningRate(lr_schedules.LearningRateSchedule):
|
||||||
|
@ -143,7 +137,7 @@ def test_bert_train():
|
||||||
batch_size = int(os.getenv('BATCH_SIZE', '1'))
|
batch_size = int(os.getenv('BATCH_SIZE', '1'))
|
||||||
inputs = load_test_data(batch_size)
|
inputs = load_test_data(batch_size)
|
||||||
|
|
||||||
config = get_config(version=version, batch_size=batch_size)
|
config = get_config(version=version)
|
||||||
netwithloss = BertNetworkWithLoss(config, True)
|
netwithloss = BertNetworkWithLoss(config, True)
|
||||||
lr = BertLearningRate(10)
|
lr = BertLearningRate(10)
|
||||||
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
||||||
|
@ -168,7 +162,7 @@ def test_bert_withlossscale_train():
|
||||||
scaling_sens = Tensor(np.ones([1]).astype(np.float32))
|
scaling_sens = Tensor(np.ones([1]).astype(np.float32))
|
||||||
inputs = load_test_data(batch_size) + (scaling_sens,)
|
inputs = load_test_data(batch_size) + (scaling_sens,)
|
||||||
|
|
||||||
config = get_config(version=version, batch_size=batch_size)
|
config = get_config(version=version)
|
||||||
netwithloss = BertNetworkWithLoss(config, True)
|
netwithloss = BertNetworkWithLoss(config, True)
|
||||||
lr = BertLearningRate(10)
|
lr = BertLearningRate(10)
|
||||||
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
||||||
|
@ -195,7 +189,7 @@ def bert_withlossscale_manager_train():
|
||||||
batch_size = int(os.getenv('BATCH_SIZE', '1'))
|
batch_size = int(os.getenv('BATCH_SIZE', '1'))
|
||||||
inputs = load_test_data(batch_size)
|
inputs = load_test_data(batch_size)
|
||||||
|
|
||||||
config = get_config(version=version, batch_size=batch_size)
|
config = get_config(version=version)
|
||||||
netwithloss = BertNetworkWithLoss(config, True)
|
netwithloss = BertNetworkWithLoss(config, True)
|
||||||
lr = BertLearningRate(10)
|
lr = BertLearningRate(10)
|
||||||
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
||||||
|
@ -223,7 +217,7 @@ def bert_withlossscale_manager_train_feed():
|
||||||
scaling_sens = Tensor(np.ones([1]).astype(np.float32))
|
scaling_sens = Tensor(np.ones([1]).astype(np.float32))
|
||||||
inputs = load_test_data(batch_size) + (scaling_sens,)
|
inputs = load_test_data(batch_size) + (scaling_sens,)
|
||||||
|
|
||||||
config = get_config(version=version, batch_size=batch_size)
|
config = get_config(version=version)
|
||||||
netwithloss = BertNetworkWithLoss(config, True)
|
netwithloss = BertNetworkWithLoss(config, True)
|
||||||
lr = BertLearningRate(10)
|
lr = BertLearningRate(10)
|
||||||
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
|
||||||
|
|
Loading…
Reference in New Issue