Catch token count issue while streaming with customized models (#3241)

* Catch token count issue while streaming with customized models

If llama, llava, phi, or other customized models are used for streaming (with stream=True), the current design crashes after fetching the response because token counting is not implemented for these models.

A warning is enough in this case, just as in the non-streaming use cases.

* Only catch NotImplementedError

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Beibin Li 2024-09-25 08:14:20 -07:00 committed by GitHub
parent c1289b4da7
commit ece69249e4
1 changed file with 6 additions and 1 deletion

@@ -279,7 +279,12 @@ class OpenAIClient:
                 # Prepare the final ChatCompletion object based on the accumulated data
                 model = chunk.model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
-                prompt_tokens = count_token(params["messages"], model)
+                try:
+                    prompt_tokens = count_token(params["messages"], model)
+                except NotImplementedError as e:
+                    # Catch token calculation error if streaming with customized models.
+                    logger.warning(str(e))
+                    prompt_tokens = 0
                 response = ChatCompletion(
                     id=chunk.id,
                     model=chunk.model,
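
For context, a minimal standalone sketch of the pattern this commit applies. The count_token stub and the model name below are illustrative stand-ins (the real helper lives in autogen and raises NotImplementedError when it cannot count tokens for a customized model); the point is that the error is downgraded to a warning and the prompt token count falls back to 0 instead of crashing the streaming path.

import logging

logger = logging.getLogger(__name__)


def count_token(messages, model):
    # Illustrative stand-in for autogen's token-counting helper, which
    # raises NotImplementedError when it has no tokenizer for a
    # customized model (e.g. a self-hosted llama/llava/phi deployment).
    raise NotImplementedError(f"count_token not implemented for model {model}")


messages = [{"role": "user", "content": "hello"}]
model = "my-custom-llava"  # hypothetical customized model name

# Before this commit the NotImplementedError propagated and the streaming
# call crashed after the response had already been fetched; now it is
# logged as a warning and the prompt token count falls back to 0.
try:
    prompt_tokens = count_token(messages, model)
except NotImplementedError as e:
    logger.warning(str(e))
    prompt_tokens = 0

print(prompt_tokens)  # -> 0, and streaming completes normally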