DAMO-ConvAI/space-2/tools/hug2space.py

130 lines
7.0 KiB
Python

import torch
from collections import OrderedDict
BERT_VOCAB_SIZE = 30522
def get_match_value(name, state_dict_numpy, prefix):
"""
Need be overridden towards different models, here for UnifiedTransformer Model
"""
if name == 'embedder.token_embedding.weight':
return state_dict_numpy[f'{prefix}embeddings.word_embeddings.weight']
elif name == 'embedder.pos_embedding.weight':
return state_dict_numpy[f'{prefix}embeddings.position_embeddings.weight']
elif name == 'embedder.type_embedding.weight':
return state_dict_numpy.get(f'{prefix}embeddings.token_type_embeddings.weight')
elif name == 'embedder.turn_embedding.weight':
return None
elif name == 'embed_layer_norm.weight':
if f'{prefix}embeddings.LayerNorm.weight' in state_dict_numpy:
return state_dict_numpy[f'{prefix}embeddings.LayerNorm.weight']
else:
return state_dict_numpy[f'{prefix}embeddings.LayerNorm.gamma']
elif name == 'embed_layer_norm.bias':
if f'{prefix}embeddings.LayerNorm.bias' in state_dict_numpy:
return state_dict_numpy[f'{prefix}embeddings.LayerNorm.bias']
else:
return state_dict_numpy[f'{prefix}embeddings.LayerNorm.beta']
elif name == 'pooler.0.weight':
return state_dict_numpy.get(f'{prefix}pooler.dense.weight')
elif name == 'pooler.0.bias':
return state_dict_numpy.get(f'{prefix}pooler.dense.bias')
elif name == 'mlm_transform.0.weight':
return state_dict_numpy.get('cls.predictions.transform.dense.weight')
elif name == 'mlm_transform.0.bias':
return state_dict_numpy.get('cls.predictions.transform.dense.bias')
elif name == 'mlm_transform.2.weight':
if 'cls.predictions.transform.LayerNorm.weight' in state_dict_numpy:
return state_dict_numpy.get('cls.predictions.transform.LayerNorm.weight')
else:
return state_dict_numpy.get('cls.predictions.transform.LayerNorm.gamma')
elif name == 'mlm_transform.2.bias':
if 'cls.predictions.transform.LayerNorm.bias' in state_dict_numpy:
return state_dict_numpy.get('cls.predictions.transform.LayerNorm.bias')
else:
return state_dict_numpy.get('cls.predictions.transform.LayerNorm.beta')
elif name == 'mlm_bias':
return state_dict_numpy.get('cls.predictions.bias')
else:
num = name.split('.')[1]
assert num in [str(i) for i in range(12)]
if name == f'layers.{num}.attn.linear_qkv.weight':
q = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.query.weight']
k = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.key.weight']
v = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.value.weight']
qkv_weight = torch.cat([q, k, v], dim=0)
return qkv_weight
elif name == f'layers.{num}.attn.linear_qkv.bias':
q = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.query.bias']
k = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.key.bias']
v = state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.self.value.bias']
qkv_bias = torch.cat([q, k, v], dim=0)
return qkv_bias
elif name == f'layers.{num}.attn.linear_out.weight':
return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.dense.weight']
elif name == f'layers.{num}.attn.linear_out.bias':
return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.dense.bias']
elif name == f'layers.{num}.attn_norm.weight':
if f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.weight' in state_dict_numpy:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.weight']
else:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.gamma']
elif name == f'layers.{num}.attn_norm.bias':
if f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.bias' in state_dict_numpy:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.bias']
else:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.attention.output.LayerNorm.beta']
elif name == f'layers.{num}.ff.linear_hidden.0.weight':
return state_dict_numpy[f'{prefix}encoder.layer.{num}.intermediate.dense.weight']
elif name == f'layers.{num}.ff.linear_hidden.0.bias':
return state_dict_numpy[f'{prefix}encoder.layer.{num}.intermediate.dense.bias']
elif name == f'layers.{num}.ff.linear_out.weight':
return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.dense.weight']
elif name == f'layers.{num}.ff.linear_out.bias':
return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.dense.bias']
elif name == f'layers.{num}.ff_norm.weight':
if f'{prefix}encoder.layer.{num}.output.LayerNorm.weight' in state_dict_numpy:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.weight']
else:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.gamma']
elif name == f'layers.{num}.ff_norm.bias':
if f'{prefix}encoder.layer.{num}.output.LayerNorm.bias' in state_dict_numpy:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.bias']
else:
return state_dict_numpy[f'{prefix}encoder.layer.{num}.output.LayerNorm.beta']
else:
raise ValueError(f'ERROR: Param "{name}" can not be loaded in Space Model!')
def hug2space(input_file, input_template, output_file):
state_dict_output = OrderedDict()
state_dict_input = torch.load(input_file, map_location=lambda storage, loc: storage)
state_dict_template = torch.load(input_template, map_location=lambda storage, loc: storage)
prefix = 'bert.' if list(state_dict_input.keys())[0].startswith('bert.') else ''
for name, value in state_dict_template.items():
match_value = get_match_value(name, state_dict_input, prefix)
if match_value is not None:
assert match_value.ndim == value.ndim
if match_value.shape != value.shape:
assert value.size(0) == BERT_VOCAB_SIZE and match_value.size(0) > BERT_VOCAB_SIZE
match_value = match_value[:BERT_VOCAB_SIZE]
dtype = value.dtype
device = value.device
state_dict_output[name] = torch.tensor(match_value, dtype=dtype, device=device)
else:
print(f'WARNING: Param "{name}" can not be loaded in Space Model.')
torch.save(state_dict_output, output_file)
if __name__ == '__main__':
# Using Example
input_template = '../model/template.model'
input_file = '../model/bert-base-uncased/pytorch_model.bin'
output_file = '../model/bert-base-uncased.model'
hug2space(input_file=input_file, input_template=input_template, output_file=output_file)
print(f'Converted!')