fix Ollama call tool many times (#1778)

## Issue
Closes #1777 

## Change
`langchain4j-ollama` Add `toolCalls` fields when ChatMessage's type is
AiMessage.

The original message do not contains the `toolCalls`:

```json
{
    "model": "qwen2:7b-instruct-q2_K",
    "messages": [
        {
            "role": "user",
            "content": "What is  3+4?"
        },
        {
            "role": "assistant"
        },
        {
            "role": "tool",
            "content": "7"
        }
    ],
    "options": {},
    "stream": false,
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "add",
                "description": "",
                "parameters": {
                    "properties": {
                        "arg1": {
                            "type": "integer"
                        },
                        "arg0": {
                            "type": "integer"
                        }
                    },
                    "required": [
                        "arg0",
                        "arg1"
                    ]
                }
            }
        }
    ]
}
```

The expect way is that the `AiMessage` should contains the toolCall
information so that the model will not continue to call the tools.

expect request:

```json
{
  "model" : "qwen2:7b-instruct-q2_K",
  "messages" : [ {
    "role" : "user",
    "content" : "What is  3+4?"
  }, {
    "role" : "assistant",
    "tool_calls" : [ {
      "function" : {
        "name" : "add",
        "arguments" : {
          "arg1" : 4,
          "arg0" : 3
        }
      }
    } ]
  }, {
    "role" : "tool",
    "content" : "7"
  } ],
  "options" : { },
  "stream" : false,
  "tools" : [ {
    "type" : "function",
    "function" : {
      "name" : "add",
      "description" : "",
      "parameters" : {
        "type" : "object",
        "properties" : {
          "arg1" : {
            "type" : "integer"
          },
          "arg0" : {
            "type" : "integer"
          }
        },
        "required" : [ "arg0", "arg1" ]
      }
    }
  } ]
}
```

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [ ] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
This commit is contained in:
ZYinNJU 2024-09-17 15:48:55 +08:00 committed by GitHub
parent b97492773a
commit 558433d203
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 24 additions and 0 deletions

View File

@ -1,17 +1,21 @@
package dev.langchain4j.model.ollama; package dev.langchain4j.model.ollama;
import com.fasterxml.jackson.core.type.TypeReference;
import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*; import dev.langchain4j.data.message.*;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static dev.langchain4j.data.message.ContentType.IMAGE; import static dev.langchain4j.data.message.ContentType.IMAGE;
import static dev.langchain4j.data.message.ContentType.TEXT; import static dev.langchain4j.data.message.ContentType.TEXT;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toJson; import static dev.langchain4j.model.ollama.OllamaJsonUtils.toJson;
import static dev.langchain4j.model.ollama.OllamaJsonUtils.toObject;
class OllamaMessagesUtils { class OllamaMessagesUtils {
@ -75,9 +79,29 @@ class OllamaMessagesUtils {
} }
private static Message otherMessages(ChatMessage chatMessage) { private static Message otherMessages(ChatMessage chatMessage) {
List<ToolCall> toolCalls = null;
if (ChatMessageType.AI == chatMessage.type()) {
AiMessage aiMessage = (AiMessage) chatMessage;
List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
toolCalls = Optional.ofNullable(toolExecutionRequests)
.map(reqs -> reqs.stream()
.map(toolExecutionRequest -> {
TypeReference<HashMap<String, Object>> typeReference = new TypeReference<HashMap<String, Object>>() {
};
FunctionCall functionCall = FunctionCall.builder()
.name(toolExecutionRequest.name())
.arguments(toObject(toolExecutionRequest.arguments(), typeReference))
.build();
return ToolCall.builder()
.function(functionCall).build();
}).collect(Collectors.toList()))
.orElse(null);
}
return Message.builder() return Message.builder()
.role(toOllamaRole(chatMessage.type())) .role(toOllamaRole(chatMessage.type()))
.content(chatMessage.text()) .content(chatMessage.text())
.toolCalls(toolCalls)
.build(); .build();
} }