!6727 modify transformer hub file
Merge pull request !6727 from yuchaojie/r1
This commit is contained in:
commit
7acc66a3a5
|
@ -41,8 +41,16 @@ def create_network(name, *args, **kwargs):
|
|||
Create the large transformer network.
|
||||
'''
|
||||
if name == 'transformer_large':
|
||||
if "batch_size" in kwargs:
|
||||
transformer_net_cfg_large.batch_size = kwargs["batch_size"]
|
||||
if "seq_length" in kwargs:
|
||||
transformer_net_cfg_large.seq_length = kwargs["seq_length"]
|
||||
if "vocab_size" in kwargs:
|
||||
transformer_net_cfg_large.vocab_size = kwargs["vocab_size"]
|
||||
is_training = kwargs.get("is_training", False)
|
||||
if not is_training:
|
||||
transformer_net_cfg_large.batch_size = 1
|
||||
transformer_net_cfg_large.hidden_dropout_prob = 0.
|
||||
transformer_net_cfg_large.attention_probs_dropout_prob = 0.
|
||||
return TransformerModel(transformer_net_cfg_large, is_training, *args)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
Loading…
Reference in New Issue