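# Convert a Hugging Face BERT checkpoint into the parameter layout expected
# by the Space (UnifiedTransformer) model: fuse the per-layer Q/K/V
# projections into a single QKV linear, map LayerNorm gamma/beta names to
# weight/bias, and truncate oversized vocabularies to BERT_VOCAB_SIZE.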
import torch
from collections import OrderedDict

# Vocabulary size of the bert-base-uncased checkpoint.
BERT_VOCAB_SIZE = 30522

def get_match_value(name, state_dict_numpy, prefix):
    """
    Needs to be overridden for different models; this version targets the
    UnifiedTransformer model.
    """
    if name == 'embedder.token_embedding.weight':
        return state_dict_numpy[f'{prefix}embeddings.word_embeddings.weight']
    elif name == 'embedder.pos_embedding.weight':
        return state_dict_numpy[f'{prefix}embeddings.position_embeddings.weight']
    elif name == 'embedder.type_embedding.weight':
        return state_dict_numpy.get(f'{prefix}embeddings.token_type_embeddings.weight')
    elif name == 'embedder.turn_embedding.weight':
        # BERT has no turn embedding; returning None leaves this parameter
        # out of the converted state dict (a warning is printed).
        return None
    elif name == 'embed_layer_norm.weight':
        # Some checkpoints store LayerNorm parameters as gamma/beta
        # instead of weight/bias; handle both namings.
        if f'{prefix}embeddings.LayerNorm.weight' in state_dict_numpy:
            return state_dict_numpy[f'{prefix}embeddings.LayerNorm.weight']
        else:
            return state_dict_numpy[f'{prefix}embeddings.LayerNorm.gamma']
    elif name == 'embed_layer_norm.bias':
        if f'{prefix}embeddings.LayerNorm.bias' in state_dict_numpy:
            return state_dict_numpy[f'{prefix}embeddings.LayerNorm.bias']
        else:
            return state_dict_numpy[f'{prefix}embeddings.LayerNorm.beta']
    elif name == 'pooler.0.weight':
        return state_dict_numpy.get(f'{prefix}pooler.dense.weight')
    elif name == 'pooler.0.bias':
        return state_dict_numpy.get(f'{prefix}pooler.dense.bias')
    elif name == 'mlm_transform.0.weight':
        return state_dict_numpy.get('cls.predictions.transform.dense.weight')
    elif name == 'mlm_transform.0.bias':
        return state_dict_numpy.get('cls.predictions.transform.dense.bias')
    elif name == 'mlm_transform.2.weight':
        if 'cls.predictions.transform.LayerNorm.weight' in state_dict_numpy:
            return state_dict_numpy.get('cls.predictions.transform.LayerNorm.weight')
        else:
            return state_dict_numpy.get('cls.predictions.transform.LayerNorm.gamma')
    elif name == 'mlm_transform.2.bias':
        if 'cls.predictions.transform.LayerNorm.bias' in state_dict_numpy:
            return state_dict_numpy.get('cls.predictions.transform.LayerNorm.bias')
        else:
            return state_dict_numpy.get('cls.predictions.transform.LayerNorm.beta')
    elif name == 'mlm_bias':
        return state_dict_numpy.get('cls.predictions.bias')
    else:
        # Encoder layers: names look like 'layers.<num>....'; bert-base has
        # 12 layers.
        num = name.split('.')[1]
        assert num in [str(i) for i in range(12)]
        if name == f'layers.{num}.attn.linear_qkv.weight':
            # BERT keeps Q, K and V as separate projections; the Space model
            # uses a single fused QKV linear, so concatenate along dim 0.
            q = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.query.weight']
            k = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.key.weight']
            v = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.value.weight']
            qkv_weight = torch.cat([q, k, v], dim=0)
            return qkv_weight
        elif name == f'layers.{num}.attn.linear_qkv.bias':
            q = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.query.bias']
            k = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.key.bias']
            v = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.value.bias']
            qkv_bias = torch.cat([q, k, v], dim=0)
            return qkv_bias
        elif name == f'layers.{num}.attn.linear_out.weight':
            return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.dense.weight']
        elif name == f'layers.{num}.attn.linear_out.bias':
            return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.dense.bias']
        elif name == f'layers.{num}.attn_norm.weight':
            if f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.weight' in state_dict_numpy:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.weight']
            else:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.gamma']
        elif name == f'layers.{num}.attn_norm.bias':
            if f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.bias' in state_dict_numpy:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.bias']
            else:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.beta']
        elif name == f'layers.{num}.ff.linear_hidden.0.weight':
            return state_dict_numpy[f'{prefix}encoder.layer.{num}.intermediate.dense.weight']
        elif name == f'layers.{num}.ff.linear_hidden.0.bias':
            return state_dict_numpy[f'{prefix}encoder.layer.{num}.intermediate.dense.bias']
        elif name == f'layers.{num}.ff.linear_out.weight':
            return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.dense.weight']
        elif name == f'layers.{num}.ff.linear_out.bias':
            return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.dense.bias']
        elif name == f'layers.{num}.ff_norm.weight':
            if f'{prefix}encoder.layer.{num}.output.LayerNorm.weight' in state_dict_numpy:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.weight']
            else:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.gamma']
        elif name == f'layers.{num}.ff_norm.bias':
            if f'{prefix}encoder.layer.{num}.output.LayerNorm.bias' in state_dict_numpy:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.bias']
            else:
                return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.beta']
        else:
            raise ValueError(f'ERROR: Param "{name}" cannot be loaded into the Space Model!')

def hug2space(input_file, input_template, output_file):
    state_dict_output = OrderedDict()
    state_dict_input = torch.load(input_file, map_location=lambda storage, loc: storage)
    state_dict_template = torch.load(input_template, map_location=lambda storage, loc: storage)
    # Some checkpoints prefix parameter names with 'bert.'; detect this from
    # the first key and strip accordingly.
    prefix = 'bert.' if list(state_dict_input.keys())[0].startswith('bert.') else ''

    for name, value in state_dict_template.items():
        match_value = get_match_value(name, state_dict_input, prefix)
        if match_value is not None:
            assert match_value.ndim == value.ndim
            if match_value.shape != value.shape:
                # Only a vocabulary-size mismatch is tolerated: truncate the
                # source embedding/bias to the BERT vocabulary size.
                assert value.size(0) == BERT_VOCAB_SIZE and match_value.size(0) > BERT_VOCAB_SIZE
                match_value = match_value[:BERT_VOCAB_SIZE]
            dtype = value.dtype
            device = value.device
            # clone().detach() copies the tensor onto the template's
            # dtype/device without the UserWarning that torch.tensor()
            # raises when handed an existing tensor.
            state_dict_output[name] = match_value.clone().detach().to(dtype=dtype, device=device)
        else:
            print(f'WARNING: Param "{name}" cannot be loaded into the Space Model.')

    torch.save(state_dict_output, output_file)

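# A converted checkpoint is consumed with the usual PyTorch loading idiom
# (sketch; `SpaceModel` is a hypothetical stand-in for the actual model
# class, which is not defined in this file):
#
#   model = SpaceModel(...)
#   model.load_state_dict(torch.load(output_file))
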
if __name__ == '__main__':
    # Usage example
    input_template = '../model/template.model'
    input_file = '../model/bert-base-uncased/pytorch_model.bin'
    output_file = '../model/bert-base-uncased.model'

    hug2space(input_file=input_file, input_template=input_template, output_file=output_file)
    print('Converted!')