fix: tool calling cohere (#3355)

* Add support for tool calling cohere

* update tool calling code

* make client name configurable with default

* formatting nits

* update docs

---------

Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
Anirudh31415926535 2024-08-29 02:47:39 +08:00 committed by GitHub
parent 2e77d3bc37
commit 5861bd92a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 26 deletions

View File

@ -6,6 +6,7 @@ Example:
"api_type": "cohere", "api_type": "cohere",
"model": "command-r-plus", "model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY") "api_key": os.environ.get("COHERE_API_KEY")
"client_name": "autogen-cohere", # Optional parameter
} }
]} ]}
@ -144,7 +145,7 @@ class CohereClient:
def create(self, params: Dict) -> ChatCompletion: def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", []) messages = params.get("messages", [])
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters # Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params) cohere_params = self.parse_params(params)
@ -156,7 +157,7 @@ class CohereClient:
cohere_params["preamble"] = preamble cohere_params["preamble"] = preamble
# We use chat model by default # We use chat model by default
client = Cohere(api_key=self.api_key) client = Cohere(api_key=self.api_key, client_name=client_name)
# Token counts will be returned # Token counts will be returned
prompt_tokens = 0 prompt_tokens = 0
@ -285,6 +286,23 @@ class CohereClient:
return response_oai return response_oai
def extract_to_cohere_tool_results(
    tool_call_id: str, content_output: str, all_tool_calls: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build Cohere tool results for every tool call matching ``tool_call_id``.

    Args:
        tool_call_id: ID of the tool call whose output is being reported.
        content_output: The tool's output; stored under the ``"value"`` key of
            the single entry in each result's ``outputs`` list.
        all_tool_calls: Tool calls to search. Each entry is read via
            ``["id"]`` and ``["function"]["name"]`` / ``["function"]["arguments"]``,
            where ``arguments`` is a JSON-encoded string.

    Returns:
        One ``ToolResult`` per entry whose ``"id"`` equals ``tool_call_id``;
        an empty list when nothing matches.
        NOTE(review): the return annotation says dicts, but the items built
        here are ``ToolResult`` objects — confirm which one callers expect.

    Raises:
        KeyError: If a tool call lacks ``"id"``, or a matching one lacks
            ``"function"``, ``"name"`` or ``"arguments"``.
        json.JSONDecodeError: If a matching call's arguments are not valid JSON.
    """
    temp_tool_results = []
    for tool_call in all_tool_calls:
        if tool_call["id"] != tool_call_id:
            continue

        function = tool_call["function"]
        call = {
            "name": function["name"],
            # Absent arguments (empty string or None) mean "no parameters";
            # json.loads would otherwise fail on "" and raise TypeError on None.
            "parameters": json.loads(function["arguments"] or "{}"),
        }
        output = [{"value": content_output}]
        temp_tool_results.append(ToolResult(call=call, outputs=output))
    return temp_tool_results
def oai_messages_to_cohere_messages( def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any] messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]: ) -> tuple[list[dict[str, Any]], str, str]:
@ -352,7 +370,8 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message' # 'content' field renamed to 'message'
# tools go into tools parameter # tools go into tools parameter
# tool_results go into tool_results parameter # tool_results go into tool_results parameter
for message in messages: messages_length = len(messages)
for index, message in enumerate(messages):
if "role" in message and message["role"] == "system": if "role" in message and message["role"] == "system":
# System message # System message
@ -369,34 +388,34 @@ def oai_messages_to_cohere_messages(
new_message = { new_message = {
"role": "CHATBOT", "role": "CHATBOT",
"message": message["content"], "message": message["content"],
# Not including tools in this message, may need to. Testing required. "tool_calls": [
{
"name": tool_call_.get("function", {}).get("name"),
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
}
for tool_call_ in message["tool_calls"]
],
} }
cohere_messages.append(new_message) cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool": elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message: if not (tool_call_id := message.get("tool_call_id")):
# Convert the tool call to a result continue
tool_call_id = message["tool_call_id"] # Convert the tool call to a result
content_output = message["content"] content_output = message["content"]
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
continue
# Find the original tool else:
for tool_call in tool_calls: # If it's not the current tool call, we pass it as a tool message in the chat history.
if tool_call["id"] == tool_call_id: new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
cohere_messages.append(new_message)
call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]
tool_results.append(ToolResult(call=call, outputs=output))
break
elif "content" in message and isinstance(message["content"], str): elif "content" in message and isinstance(message["content"], str):
# Standard text message # Standard text message
new_message = { new_message = {
@ -416,7 +435,7 @@ def oai_messages_to_cohere_messages(
# If we're adding tool_results, like we are, the last message can't be a USER message # If we're adding tool_results, like we are, the last message can't be a USER message
# So, we add a CHATBOT 'continue' message, if so. # So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay) # Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"] == "USER": if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."}) cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
# We return a blank message when we have tool results # We return a blank message when we have tool results

View File

@ -100,6 +100,7 @@
"- seed (null, integer)\n", "- seed (null, integer)\n",
"- frequency_penalty (number 0..1)\n", "- frequency_penalty (number 0..1)\n",
"- presence_penalty (number 0..1)\n", "- presence_penalty (number 0..1)\n",
"- client_name (null, string)\n",
"\n", "\n",
"Example:\n", "Example:\n",
"```python\n", "```python\n",
@ -108,6 +109,7 @@
" \"model\": \"command-r\",\n", " \"model\": \"command-r\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n", " \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\",\n", " \"api_type\": \"cohere\",\n",
" \"client_name\": \"autogen-cohere\",\n",
" \"temperature\": 0.5,\n", " \"temperature\": 0.5,\n",
" \"p\": 0.2,\n", " \"p\": 0.2,\n",
" \"k\": 100,\n", " \"k\": 100,\n",
@ -526,7 +528,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.9" "version": "3.12.5"
} }
}, },
"nbformat": 4, "nbformat": 4,