Improve detokenization performance (#1338)

Antoni Baum 2023-10-13 09:59:07 -07:00 committed by GitHub
parent 6368e777a8
commit ec3b5ce9cc
1 changed file with 4 additions and 3 deletions


@@ -81,10 +81,11 @@ def _convert_tokens_to_string_with_added_encoders(
     # even when the loop body is very simple.
     sub_texts = []
     current_sub_text = []
+    all_special_tokens = set(tokenizer.all_special_tokens)
     for token in output_tokens:
-        if skip_special_tokens and token in tokenizer.all_special_tokens:
+        if skip_special_tokens and token in all_special_tokens:
             continue
-        if token in tokenizer.added_tokens_encoder:
+        if token in tokenizer.get_added_vocab():
             if current_sub_text:
                 sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                 sub_texts.append(sub_text)
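
The first hunk speeds up the per-token loop in _convert_tokens_to_string_with_added_encoders: tokenizer.all_special_tokens is a list on Hugging Face tokenizers, so testing membership against it rescans the list on every iteration, whereas hoisting it into a set once gives average O(1) checks; it also swaps the added_tokens_encoder attribute for get_added_vocab(), which both slow and fast tokenizers provide. A minimal sketch of the pattern, assuming a transformers install; the gpt2 checkpoint and sample text are illustrative only, not taken from the commit:

# Minimal sketch, not verbatim vLLM code: why hoisting the special-token set
# out of the loop helps. all_special_tokens is a list, so `token in` against
# it is O(n) per token; the set and the get_added_vocab() dict are ~O(1).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")        # example checkpoint
output_tokens = tokenizer.tokenize("Hello world!")       # stand-in token stream
skip_special_tokens = True

all_special_tokens = set(tokenizer.all_special_tokens)   # built once, outside the loop
added_vocab = tokenizer.get_added_vocab()                # dict: added token -> id

kept_tokens = []
for token in output_tokens:
    if skip_special_tokens and token in all_special_tokens:
        continue                                         # O(1) set lookup per token
    kept_tokens.append(token)                            # added_vocab lookups are O(1) too

print(tokenizer.convert_tokens_to_string(kept_tokens))
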
@@ -129,7 +130,7 @@ def detokenize_incrementally(
     # The prefix text is necessary only to defeat cleanup algorithms in
     # the decode which decide to add a space or not depending on the
     # surrounding ids.
-    if not getattr(tokenizer, "added_tokens_encoder", {}):
+    if tokenizer.is_fast or not tokenizer.get_added_vocab():
         prefix_text = tokenizer.convert_tokens_to_string(
             output_tokens[prefix_offset:read_offset])
         new_text = tokenizer.convert_tokens_to_string(
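
The second hunk relaxes the condition guarding the direct convert_tokens_to_string path in detokenize_incrementally: fast (Rust-backed) tokenizers now always take it, and only slow tokenizers that actually carry added vocabulary fall back to the added-encoders routine patched above. A small sketch of the two properties the new condition reads, assuming transformers is installed; the checkpoint and the added token are illustrative:

# Minimal sketch, not vLLM code: the tokenizer properties behind the new check.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # loads the fast tokenizer by default
print(tok.is_fast)                            # True for PreTrainedTokenizerFast instances
print(tok.get_added_vocab())                  # {} until tokens are added

tok.add_tokens(["<custom_token>"])            # hypothetical added token
print(tok.get_added_vocab())                  # now maps '<custom_token>' to its new id

# With the patched condition, a fast tokenizer still takes the direct
# convert_tokens_to_string path even though get_added_vocab() is non-empty.
if tok.is_fast or not tok.get_added_vocab():
    print("direct convert_tokens_to_string path")
else:
    print("added-encoders fallback path")
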