mirror of https://github.com/microsoft/autogen.git
fix: tool calling cohere (#3355)
* Add support for tool calling cohere
* update tool calling code
* make client name configurable with default
* formatting nits
* update docs

---------

Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
parent 2e77d3bc37
commit 5861bd92a6
@@ -6,6 +6,7 @@ Example:
             "api_type": "cohere",
             "model": "command-r-plus",
             "api_key": os.environ.get("COHERE_API_KEY")
+            "client_name": "autogen-cohere", # Optional parameter
             }
         ]}

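For context, a minimal sketch of how a config list like the one in this docstring could be wired into AutoGen agents for tool calling; the agent setup and the get_weather tool are illustrative and not part of this commit:

import os

from autogen import ConversableAgent

config_list = [
    {
        "api_type": "cohere",
        "model": "command-r-plus",
        "api_key": os.environ.get("COHERE_API_KEY"),
        "client_name": "autogen-cohere",  # Optional parameter
    }
]

assistant = ConversableAgent("assistant", llm_config={"config_list": config_list})
user_proxy = ConversableAgent("user_proxy", llm_config=False, human_input_mode="NEVER")


def get_weather(city: str) -> str:
    """Illustrative tool; replace with a real implementation."""
    return f"Sunny in {city}"


# Suggested by the Cohere-backed assistant, executed by the user proxy.
assistant.register_for_llm(name="get_weather", description="Get the weather for a city")(get_weather)
user_proxy.register_for_execution(name="get_weather")(get_weather)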
@@ -144,7 +145,7 @@ class CohereClient:
     def create(self, params: Dict) -> ChatCompletion:
         messages = params.get("messages", [])
-
+        client_name = params.get("client_name") or "autogen-cohere"
         # Parse parameters to the Cohere API's parameters
         cohere_params = self.parse_params(params)

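A small standalone sketch of the default behaviour introduced here: because the fallback uses `or`, any falsy value (missing key, None, or an empty string) resolves to "autogen-cohere". The resolve_client_name wrapper is purely illustrative:

def resolve_client_name(params: dict) -> str:
    # Same expression as the new line in create().
    return params.get("client_name") or "autogen-cohere"


assert resolve_client_name({}) == "autogen-cohere"
assert resolve_client_name({"client_name": None}) == "autogen-cohere"
assert resolve_client_name({"client_name": "my-app"}) == "my-app"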
@@ -156,7 +157,7 @@ class CohereClient:
             cohere_params["preamble"] = preamble

         # We use chat model by default
-        client = Cohere(api_key=self.api_key)
+        client = Cohere(api_key=self.api_key, client_name=client_name)

         # Token counts will be returned
         prompt_tokens = 0
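For reference, a standalone sketch of the call this change produces, assuming the Cohere alias refers to the Cohere SDK's Client (as the usage above suggests) and that the SDK accepts client_name to identify the calling application; the environment variable and name are placeholders:

import os

from cohere import Client as Cohere

# client_name tags outgoing requests so Cohere can attribute this application's traffic.
client = Cohere(api_key=os.environ["COHERE_API_KEY"], client_name="my-autogen-app")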
@@ -285,6 +286,23 @@ class CohereClient:
         return response_oai


+def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
+    temp_tool_results = []
+
+    for tool_call in all_tool_calls:
+        if tool_call["id"] == tool_call_id:
+
+            call = {
+                "name": tool_call["function"]["name"],
+                "parameters": json.loads(
+                    tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] == "" else "{}"
+                ),
+            }
+            output = [{"value": content_output}]
+            temp_tool_results.append(ToolResult(call=call, outputs=output))
+    return temp_tool_results
+
+
 def oai_messages_to_cohere_messages(
     messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
 ) -> tuple[list[dict[str, Any]], str, str]:
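A minimal usage sketch of the new helper with made-up OpenAI-style tool-call data (the IDs and the get_weather name are illustrative):

all_tool_calls = [
    {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
    {"id": "call_2", "function": {"name": "get_weather", "arguments": ""}},  # empty arguments parse to {}
]

# Only the entry whose id matches tool_call_id is converted.
results = extract_to_cohere_tool_results("call_1", "Sunny, 22C", all_tool_calls)
# results[0] is a cohere ToolResult whose call names "get_weather" with
# parameters {"city": "Paris"} and whose outputs are [{"value": "Sunny, 22C"}].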
@@ -352,7 +370,8 @@ oai_messages_to_cohere_messages(
     # 'content' field renamed to 'message'
     # tools go into tools parameter
     # tool_results go into tool_results parameter
-    for message in messages:
+    messages_length = len(messages)
+    for index, message in enumerate(messages):

         if "role" in message and message["role"] == "system":
             # System message
@@ -369,34 +388,34 @@ oai_messages_to_cohere_messages(
             new_message = {
                 "role": "CHATBOT",
                 "message": message["content"],
-                # Not including tools in this message, may need to. Testing required.
+                "tool_calls": [
+                    {
+                        "name": tool_call_.get("function", {}).get("name"),
+                        "parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
+                    }
+                    for tool_call_ in message["tool_calls"]
+                ],
             }

             cohere_messages.append(new_message)
         elif "role" in message and message["role"] == "tool":
-            if "tool_call_id" in message:
-                # Convert the tool call to a result
-                tool_call_id = message["tool_call_id"]
-                content_output = message["content"]
+            if not (tool_call_id := message.get("tool_call_id")):
+                continue
+
+            # Convert the tool call to a result
+            content_output = message["content"]
+            tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
+            if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
+                # If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
+                # So, we pass it into tool_results.
+                tool_results.extend(tool_results_chat_turn)
+                continue

-                # Find the original tool
-                for tool_call in tool_calls:
-                    if tool_call["id"] == tool_call_id:
-
-                        call = {
-                            "name": tool_call["function"]["name"],
-                            "parameters": json.loads(
-                                tool_call["function"]["arguments"]
-                                if not tool_call["function"]["arguments"] == ""
-                                else "{}"
-                            ),
-                        }
-                        output = [{"value": content_output}]
-
-                        tool_results.append(ToolResult(call=call, outputs=output))
-
-                        break
+            else:
+                # If its not the current tool call, we pass it as a tool message in the chat history.
+                new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
+                cohere_messages.append(new_message)

         elif "content" in message and isinstance(message["content"], str):
             # Standard text message
             new_message = {
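To make the new routing rule concrete, an illustrative conversation (roles and ordering are what matter; the names and content are made up):

messages = [
    {"role": "user", "content": "Weather in Paris and Tokyo?"},
    {"role": "assistant", "content": "", "tool_calls": [
        {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        {"id": "call_2", "function": {"name": "get_weather", "arguments": '{"city": "Tokyo"}'}},
    ]},
    {"role": "tool", "tool_call_id": "call_1", "content": "Sunny"},  # next message is also a tool message
    {"role": "tool", "tool_call_id": "call_2", "content": "Rainy"},  # last message in the list
]

# Both tool messages satisfy
#     index == messages_length - 1  OR  next role in ("user", "tool")
# so their results go into tool_results for the current Cohere chat turn.
# A tool message followed by an assistant reply would instead be appended to
# the chat history as a {"role": "TOOL", "tool_results": [...]} entry.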
@@ -416,7 +435,7 @@
     # If we're adding tool_results, like we are, the last message can't be a USER message
     # So, we add a CHATBOT 'continue' message, if so.
     # Changed key from "content" to "message" (jaygdesai/autogen_Jay)
-    if cohere_messages[-1]["role"] == "USER":
+    if cohere_messages[-1]["role"].lower() == "user":
         cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})

     # We return a blank message when we have tool results

The remaining hunks update the Cohere documentation notebook (the "update docs" item in the commit message).

@@ -100,6 +100,7 @@
     "- seed (null, integer)\n",
     "- frequency_penalty (number 0..1)\n",
     "- presence_penalty (number 0..1)\n",
+    "- client_name (null, string)\n",
     "\n",
     "Example:\n",
     "```python\n",
@@ -108,6 +109,7 @@
     "        \"model\": \"command-r\",\n",
     "        \"api_key\": \"your Cohere API Key goes here\",\n",
     "        \"api_type\": \"cohere\",\n",
+    "        \"client_name\": \"autogen-cohere\",\n",
     "        \"temperature\": 0.5,\n",
     "        \"p\": 0.2,\n",
     "        \"k\": 100,\n",
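Assembled from the notebook cell above, the documented config entry would look roughly like this; the API key is a placeholder, and client_name can be omitted, in which case the client falls back to "autogen-cohere":

import os

config_list = [
    {
        "model": "command-r",
        "api_key": os.environ.get("COHERE_API_KEY"),
        "api_type": "cohere",
        "client_name": "autogen-cohere",
        "temperature": 0.5,
        "p": 0.2,
        "k": 100,
    }
]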
@@ -526,7 +528,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.12.5"
   }
  },
 "nbformat": 4,