fix: tool calling cohere (#3355)

* Add support for tool calling cohere

* update tool calling code

* make client name configurable with default

* formatting nits

* update docs

---------

Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
Anirudh31415926535 2024-08-29 02:47:39 +08:00 committed by GitHub
parent 2e77d3bc37
commit 5861bd92a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 26 deletions

View File

@ -6,6 +6,7 @@ Example:
"api_type": "cohere",
"model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY")
"client_name": "autogen-cohere", # Optional parameter
}
]}
@ -144,7 +145,7 @@ class CohereClient:
def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params)
@ -156,7 +157,7 @@ class CohereClient:
cohere_params["preamble"] = preamble
# We use chat model by default
client = Cohere(api_key=self.api_key)
client = Cohere(api_key=self.api_key, client_name=client_name)
# Token counts will be returned
prompt_tokens = 0
@ -285,6 +286,23 @@ class CohereClient:
return response_oai
def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
temp_tool_results = []
for tool_call in all_tool_calls:
if tool_call["id"] == tool_call_id:
call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] == "" else "{}"
),
}
output = [{"value": content_output}]
temp_tool_results.append(ToolResult(call=call, outputs=output))
return temp_tool_results
def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
@ -352,7 +370,8 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
for message in messages:
messages_length = len(messages)
for index, message in enumerate(messages):
if "role" in message and message["role"] == "system":
# System message
@ -369,34 +388,34 @@ def oai_messages_to_cohere_messages(
new_message = {
"role": "CHATBOT",
"message": message["content"],
# Not including tools in this message, may need to. Testing required.
"tool_calls": [
{
"name": tool_call_.get("function", {}).get("name"),
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
}
for tool_call_ in message["tool_calls"]
],
}
cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message:
# Convert the tool call to a result
if not (tool_call_id := message.get("tool_call_id")):
continue
tool_call_id = message["tool_call_id"]
content_output = message["content"]
# Convert the tool call to a result
content_output = message["content"]
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
continue
# Find the original tool
for tool_call in tool_calls:
if tool_call["id"] == tool_call_id:
else:
# If its not the current tool call, we pass it as a tool message in the chat history.
new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
cohere_messages.append(new_message)
call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]
tool_results.append(ToolResult(call=call, outputs=output))
break
elif "content" in message and isinstance(message["content"], str):
# Standard text message
new_message = {
@ -416,7 +435,7 @@ def oai_messages_to_cohere_messages(
# If we're adding tool_results, like we are, the last message can't be a USER message
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"] == "USER":
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
# We return a blank message when we have tool results

View File

@ -100,6 +100,7 @@
"- seed (null, integer)\n",
"- frequency_penalty (number 0..1)\n",
"- presence_penalty (number 0..1)\n",
"- client_name (null, string)\n",
"\n",
"Example:\n",
"```python\n",
@ -108,6 +109,7 @@
" \"model\": \"command-r\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\",\n",
" \"client_name\": \"autogen-cohere\",\n",
" \"temperature\": 0.5,\n",
" \"p\": 0.2,\n",
" \"k\": 100,\n",
@ -526,7 +528,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,