Fix: #950 ZhipuAiChatModel does not support tools without parameters (#999)

## Context
Fix: #950

## Checklist
Before submitting this PR, please check the following points:
- [x] I have added unit and integration tests for my change
- [x] All unit and integration tests in the module I have added/changed
are green
This commit is contained in:
二毛 2024-05-22 23:33:55 +08:00 committed by GitHub
parent acdefd34b0
commit c27c127912
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 0 deletions

View File

@ -44,6 +44,9 @@ class DefaultZhipuAiHelper {
}
private static Parameters toFunctionParameters(ToolParameters toolParameters) {
if (toolParameters == null) {
return Parameters.builder().build();
}
return Parameters.builder()
.properties(toolParameters.properties())
.required(toolParameters.required())

View File

@ -98,4 +98,52 @@ class ZhipuAiChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
ToolSpecification currentTime = ToolSpecification.builder()
.name("currentTime")
.description("currentTime")
.build();
@Test
void should_execute_get_current_time_tool_and_then_answer() {
// given
UserMessage userMessage = userMessage("What's the time now?");
List<ToolSpecification> toolSpecifications = singletonList(currentTime);
// when
Response<AiMessage> response = chatModel.generate(singletonList(userMessage), toolSpecifications);
// then
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("currentTime");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "2024-04-23 12:00:20");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
// when
Response<AiMessage> secondResponse = chatModel.generate(messages);
// then
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("2024-04-23 12:00:20");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
}

View File

@ -112,4 +112,56 @@ public class ZhipuAiStreamingChatModelIT {
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
ToolSpecification currentTime = ToolSpecification.builder()
.name("currentTime")
.description("currentTime")
.build();
@Test
void should_execute_get_current_time_tool_and_then_answer() {
// given
UserMessage userMessage = userMessage("What's the time now?");
List<ToolSpecification> toolSpecifications = singletonList(currentTime);
// when
TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
model.generate(singletonList(userMessage), toolSpecifications, handler);
// then
Response<AiMessage> response = handler.get();
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("currentTime");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.totalTokenCount())
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
// given
ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "2024-04-23 12:00:20");
List<ChatMessage> messages = asList(userMessage, aiMessage, toolExecutionResultMessage);
// when
TestStreamingResponseHandler<AiMessage> secondHandler = new TestStreamingResponseHandler<>();
model.generate(messages, secondHandler);
// then
Response<AiMessage> secondResponse = secondHandler.get();
AiMessage secondAiMessage = secondResponse.content();
assertThat(secondAiMessage.text()).contains("2024-04-23 12:00:20");
assertThat(secondAiMessage.toolExecutionRequests()).isNull();
TokenUsage secondTokenUsage = secondResponse.tokenUsage();
assertThat(secondTokenUsage.totalTokenCount())
.isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount());
assertThat(secondResponse.finishReason()).isEqualTo(STOP);
}
}