fix:Spark's large language model token calculation error #7911 (#8755)

This commit is contained in:
cherryhuahua 2024-09-25 14:51:42 +08:00 committed by GitHub
parent 2328944987
commit d0e0111f88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -309,7 +309,7 @@ class AppRunner:
if not prompt_messages: if not prompt_messages:
prompt_messages = result.prompt_messages prompt_messages = result.prompt_messages
if not usage and result.delta.usage: if result.delta.usage:
usage = result.delta.usage usage = result.delta.usage
if not usage: if not usage:

View File

@ -213,18 +213,21 @@ class SparkLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages :param prompt_messages: prompt messages
:return: llm response chunk generator result :return: llm response chunk generator result
""" """
completion = ""
for index, content in enumerate(client.subscribe()): for index, content in enumerate(client.subscribe()):
if isinstance(content, dict): if isinstance(content, dict):
delta = content["data"] delta = content["data"]
else: else:
delta = content delta = content
completion += delta
assistant_prompt_message = AssistantPromptMessage( assistant_prompt_message = AssistantPromptMessage(
content=delta or "", content=delta or "",
) )
temp_assistant_prompt_message = AssistantPromptMessage(
content=completion,
)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) completion_tokens = self.get_num_tokens(model, credentials, [temp_assistant_prompt_message])
# transform usage # transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)