mirror of https://github.com/microsoft/autogen.git
Function Calling Support for Gemini - Part 2 (#3726)
* Added function calling support to GeminiClient
* Append a "continue" message from the model to alternate between user and model
* Fixed cost calculation to include <128K pricing and the new 1.5-flash model; added a test case for function_call
* Add notebook with samples for Gemini
* Updated test case
* Fix to handle non-dict responses in function calls
* Handle non-dict function results and append a dummy model message between the function response and the user message
* Fix message order in Gemini
* Append text as multiple parts instead of text concatenation
* Raise an error for Union data types in function parameters
* Delete the default key
* Update gemini.py for multiple tool calls + pre-commit formatting
* No function role
* Start adding function calling config
* Do not serialize tool_config
* Improve tool config parsing
* Add hint
* Improve function calling config
* Remove unnecessary comments
* Try removing allowed function names in tool config conversion
* Fix tool config parsing with an empty tools list
* Improve logging and case handling with vertexai tool config parsing
* Reset file
* Check if text is in part
* Fix empty part checking case
* Fix bug with attribute handling
* Skip test if Gemini deps are not installed

---------

Co-authored-by: Arjun G <arjun@arjun-g.com>
Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
parent 3ebd7aeec2
commit 32022b2df6
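For context, here is a minimal usage sketch of the new function-calling path. This is an editor's illustration, not part of the PR: the client construction, the api_key value, and the calculator tool are assumptions; the parameter shapes mirror the unit tests included in this diff.

    # Illustrative only: exercise GeminiClient.create with OpenAI-style tools.
    from autogen.oai.gemini import GeminiClient

    client = GeminiClient(api_key="YOUR_GEMINI_API_KEY")  # hypothetical key

    response = client.create(
        {
            "model": "gemini-1.5-pro",
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
            # Tools arrive in the OpenAI tool format and are converted to Gemini
            # FunctionDeclaration objects by _oai_tools_to_gemini_tools.
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "calculator",
                        "description": "Evaluate a basic arithmetic expression.",
                        "parameters": {
                            "type": "object",
                            "properties": {"expression": {"type": "string"}},
                            "required": ["expression"],
                        },
                    },
                }
            ],
            # Converted to a VertexAI ToolConfig on the vertexai code path.
            "tool_config": {"function_calling_config": {"mode": "ANY"}},
            "stream": False,
        }
    )
    print(response.choices[0].message.tool_calls)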
@@ -32,6 +32,8 @@ Resources:
from __future__ import annotations

import base64
import copy
import json
import logging
import os
import random

@@ -39,24 +41,39 @@ import re
import time
import warnings
from io import BytesIO
from typing import Any, Dict, List, Mapping, Union
from typing import Any, Dict, List, Union

import google.generativeai as genai
import requests
import vertexai
from google.ai.generativelanguage import Content, Part
from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import (
    Content as VertexAIContent,
)
from vertexai.generative_models import (
    FunctionDeclaration as VertexAIFunctionDeclaration,
)
from vertexai.generative_models import (
    GenerationConfig as VertexAIGenerationConfig,
)
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
from vertexai.generative_models import (
    Tool as VertexAITool,
)
from vertexai.generative_models import (
    ToolConfig as VertexAIToolConfig,
)

logger = logging.getLogger(__name__)
@ -107,7 +124,7 @@ class GeminiClient:
|
|||
|
||||
Args:
|
||||
api_key (str): The API key for using Gemini.
|
||||
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
|
||||
google_application_credentials (str): Path to the JSON service account key file of the service account.
|
||||
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
||||
can also be set instead of using this argument.
|
||||
|
@ -171,6 +188,8 @@ class GeminiClient:
|
|||
|
||||
params.get("api_type", "google") # not used
|
||||
messages = params.get("messages", [])
|
||||
tools = params.get("tools", [])
|
||||
tool_config = params.get("tool_config", {})
|
||||
stream = params.get("stream", False)
|
||||
n_response = params.get("n", 1)
|
||||
system_instruction = params.get("system_instruction", None)
|
||||
|
@ -183,6 +202,7 @@ class GeminiClient:
|
|||
}
|
||||
if self.use_vertexai:
|
||||
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
|
||||
tool_config = GeminiClient._to_vertexai_tool_config(tool_config, tools)
|
||||
else:
|
||||
safety_settings = params.get("safety_settings", {})
|
||||
|
||||
|
@ -198,12 +218,15 @@ class GeminiClient:
|
|||
if "vision" not in model_name:
|
||||
# A. create and call the chat model.
|
||||
gemini_messages = self._oai_messages_to_gemini_messages(messages)
|
||||
gemini_tools = self._oai_tools_to_gemini_tools(tools)
|
||||
if self.use_vertexai:
|
||||
model = GenerativeModel(
|
||||
model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
system_instruction=system_instruction,
|
||||
tools=gemini_tools,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
|
||||
else:
|
||||
|
@ -213,12 +236,13 @@ class GeminiClient:
|
|||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
system_instruction=system_instruction,
|
||||
tools=gemini_tools,
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
chat = model.start_chat(history=gemini_messages[:-1])
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
ans = None
|
||||
ans: Union[Content, VertexAIContent] = None
|
||||
try:
|
||||
response = chat.send_message(
|
||||
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
|
||||
|
@ -234,7 +258,7 @@ class GeminiClient:
|
|||
raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
|
||||
else:
|
||||
# `ans = response.text` is unstable. Use the following code instead.
|
||||
ans: str = chat.history[-1].parts[0].text
|
||||
ans: Union[Content, VertexAIContent] = chat.history[-1]
|
||||
break
|
||||
|
||||
if ans is None:
|
||||
|
@ -262,7 +286,7 @@ class GeminiClient:
|
|||
# Gemini's vision model does not support chat history yet
|
||||
# chat = model.start_chat(history=gemini_messages[:-1])
|
||||
# response = chat.send_message(gemini_messages[-1].parts)
|
||||
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
|
||||
user_message = self._oai_content_to_gemini_content(messages[-1])
|
||||
if len(messages) > 2:
|
||||
warnings.warn(
|
||||
"Warning: Gemini's vision model does not support chat history yet.",
|
||||
|
@ -273,16 +297,14 @@ class GeminiClient:
|
|||
response = model.generate_content(user_message, stream=stream)
|
||||
# ans = response.text
|
||||
if self.use_vertexai:
|
||||
ans: str = response.candidates[0].content.parts[0].text
|
||||
ans: VertexAIContent = response.candidates[0].content
|
||||
else:
|
||||
ans: str = response._result.candidates[0].content.parts[0].text
|
||||
ans: Content = response._result.candidates[0].content
|
||||
|
||||
prompt_tokens = model.count_tokens(user_message).total_tokens
|
||||
completion_tokens = model.count_tokens(ans).total_tokens
|
||||
completion_tokens = model.count_tokens(ans.parts[0].text).total_tokens
|
||||
|
||||
# 3. convert output
|
||||
message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
|
||||
choices = [Choice(finish_reason="stop", index=0, message=message)]
|
||||
choices = self._gemini_content_to_oai_choices(ans)
|
||||
|
||||
response_oai = ChatCompletion(
|
||||
id=str(random.randint(0, 1000)),
|
||||
|
@ -295,31 +317,87 @@ class GeminiClient:
|
|||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
|
||||
cost=self._calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
|
||||
)
|
||||
|
||||
return response_oai
|
||||
|
||||
def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
|
||||
# If str is not a JSON string, return str as is
|
||||
def _to_json(self, str) -> dict:
|
||||
try:
|
||||
return json.loads(str)
|
||||
except ValueError:
|
||||
return str
|
||||
|
||||
def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> List:
|
||||
"""Convert content from OAI format to Gemini format"""
|
||||
rst = []
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
content = "empty" # Empty content is not allowed.
|
||||
if isinstance(message["content"], str):
|
||||
if message["content"] == "":
|
||||
message["content"] = "empty" # Empty content is not allowed.
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIPart.from_text(content))
|
||||
rst.append(VertexAIPart.from_text(message["content"]))
|
||||
else:
|
||||
rst.append(Part(text=content))
|
||||
rst.append(Part(text=message["content"]))
|
||||
return rst
|
||||
|
||||
assert isinstance(content, list)
|
||||
if "tool_calls" in message:
|
||||
if self.use_vertexai:
|
||||
for tool_call in message["tool_calls"]:
|
||||
rst.append(
|
||||
VertexAIPart.from_dict(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": json.loads(tool_call["function"]["arguments"]),
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
for tool_call in message["tool_calls"]:
|
||||
rst.append(
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name=tool_call["function"]["name"],
|
||||
args=json.loads(tool_call["function"]["arguments"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
return rst
|
||||
|
||||
for msg in content:
|
||||
if message["role"] == "tool":
|
||||
if self.use_vertexai:
|
||||
rst.append(
|
||||
VertexAIPart.from_function_response(
|
||||
name=message["name"], response={"result": self._to_json(message["content"])}
|
||||
)
|
||||
)
|
||||
else:
|
||||
rst.append(
|
||||
Part(
|
||||
function_response=FunctionResponse(
|
||||
name=message["name"], response={"result": self._to_json(message["content"])}
|
||||
)
|
||||
)
|
||||
)
|
||||
return rst
|
||||
|
||||
if isinstance(message["content"], str):
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIPart.from_text(message["content"]))
|
||||
else:
|
||||
rst.append(Part(text=message["content"]))
|
||||
return rst
|
||||
|
||||
assert isinstance(message["content"], list)
|
||||
|
||||
for msg in message["content"]:
|
||||
if isinstance(msg, dict):
|
||||
assert "type" in msg, f"Missing 'type' field in message: {msg}"
|
||||
if msg["type"] == "text":
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIPart.from_text(text=msg["text"]))
|
||||
rst.append(VertexAIPart.from_text(msg["text"]))
|
||||
else:
|
||||
rst.append(Part(text=msg["text"]))
|
||||
elif msg["type"] == "image_url":
|
||||
|
@ -340,34 +418,32 @@ class GeminiClient:
|
|||
raise ValueError(f"Unsupported message type: {type(msg)}")
|
||||
return rst
|
||||
|
||||
def _concat_parts(self, parts: List[Part]) -> List:
|
||||
"""Concatenate parts with the same type.
|
||||
If two adjacent parts both have the "text" attribute, then it will be joined into one part.
|
||||
"""
|
||||
if not parts:
|
||||
return []
|
||||
def _calculate_gemini_cost(self, input_tokens: int, output_tokens: int, model_name: str) -> float:
|
||||
if "1.5-pro" in model_name:
|
||||
if (input_tokens + output_tokens) <= 128000:
|
||||
# "gemini-1.5-pro"
|
||||
# When total tokens is less than 128K cost is $3.5 per million input tokens and $10.5 per million output tokens
|
||||
return 3.5 * input_tokens / 1e6 + 10.5 * output_tokens / 1e6
|
||||
# "gemini-1.5-pro"
|
||||
# Cost is $7 per million input tokens and $21 per million output tokens
|
||||
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
|
||||
|
||||
concatenated_parts = []
|
||||
previous_part = parts[0]
|
||||
if "1.5-flash" in model_name:
|
||||
if (input_tokens + output_tokens) <= 128000:
|
||||
# "gemini-1.5-flash"
|
||||
# Cost is $0.35 per million input tokens and $1.05 per million output tokens
|
||||
return 0.35 * input_tokens / 1e6 + 1.05 * output_tokens / 1e6
|
||||
# "gemini-1.5-flash"
|
||||
# When total tokens exceed 128K, cost is $0.70 per million input tokens and $2.10 per million output tokens
|
||||
return 0.70 * input_tokens / 1e6 + 2.10 * output_tokens / 1e6
|
||||
|
||||
for current_part in parts[1:]:
|
||||
if previous_part.text != "":
|
||||
if self.use_vertexai:
|
||||
previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
|
||||
else:
|
||||
previous_part.text += current_part.text
|
||||
else:
|
||||
concatenated_parts.append(previous_part)
|
||||
previous_part = current_part
|
||||
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
|
||||
warnings.warn(
|
||||
f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning
|
||||
)
|
||||
|
||||
if previous_part.text == "":
|
||||
if self.use_vertexai:
|
||||
previous_part = VertexAIPart.from_text("empty")
|
||||
else:
|
||||
previous_part.text = "empty" # Empty content is not allowed.
|
||||
concatenated_parts.append(previous_part)
|
||||
|
||||
return concatenated_parts
|
||||
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
|
||||
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
|
||||
|
||||
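# Editor's note (illustrative, not part of the PR): a quick sanity check of the
# per-million-token rates implemented in _calculate_gemini_cost above.
#   gemini-1.5-pro, 100,000 input + 20,000 output tokens (total <= 128K):
#       3.5 * 100000 / 1e6 + 10.5 * 20000 / 1e6 = 0.35 + 0.21 = 0.56 USD
#   gemini-1.5-flash, 200,000 input + 10,000 output tokens (total > 128K):
#       0.70 * 200000 / 1e6 + 2.10 * 10000 / 1e6 = 0.14 + 0.021 = 0.161 USD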
def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert messages from OAI format to Gemini format.
|
||||
|
@ -376,38 +452,154 @@ class GeminiClient:
|
|||
"""
|
||||
prev_role = None
|
||||
rst = []
|
||||
curr_parts = []
|
||||
for i, message in enumerate(messages):
|
||||
parts = self._oai_content_to_gemini_content(message["content"])
|
||||
role = "user" if message["role"] in ["user", "system"] else "model"
|
||||
if (prev_role is None) or (role == prev_role):
|
||||
curr_parts += parts
|
||||
elif role != prev_role:
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
|
||||
else:
|
||||
rst.append(Content(parts=curr_parts, role=prev_role))
|
||||
curr_parts = parts
|
||||
prev_role = role
|
||||
|
||||
# handle the last message
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=curr_parts, role=role))
|
||||
else:
|
||||
rst.append(Content(parts=curr_parts, role=role))
|
||||
def append_parts(parts, role):
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=parts, role=role))
|
||||
else:
|
||||
rst.append(Content(parts=parts, role=role))
|
||||
|
||||
def append_text_to_last(text):
|
||||
if self.use_vertexai:
|
||||
rst[-1] = VertexAIContent(parts=[*rst[-1].parts, VertexAIPart.from_text(text)], role=rst[-1].role)
|
||||
else:
|
||||
rst[-1] = Content(parts=[*rst[-1].parts, Part(text=text)], role=rst[-1].role)
|
||||
|
||||
def is_function_call(parts):
|
||||
return self.use_vertexai and parts[0].function_call or not self.use_vertexai and "function_call" in parts[0]
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
|
||||
# Since the tool call message does not have the "name" field, we need to find the corresponding tool message.
|
||||
if message["role"] == "tool":
|
||||
message["name"] = [
|
||||
m["tool_calls"][i]["function"]["name"]
|
||||
for m in messages
|
||||
if "tool_calls" in m
|
||||
for i, tc in enumerate(m["tool_calls"])
|
||||
if tc["id"] == message["tool_call_id"]
|
||||
][0]
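# Editor's note (illustrative, not from the PR): for a history containing
#   {"role": "assistant", "tool_calls": [{"id": "call_1", "function": {"name": "calculator", ...}}]}
# followed by
#   {"role": "tool", "tool_call_id": "call_1", "content": "4"},
# the lookup above resolves message["name"] to "calculator", so the Gemini
# FunctionResponse part can carry the original function name.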
|
||||
|
||||
parts = self._oai_content_to_gemini_content(message)
|
||||
role = "user" if message["role"] in ["user", "system"] else "model"
|
||||
|
||||
# In Gemini if the current message is a function call then previous message should not be a model message.
|
||||
if is_function_call(parts):
|
||||
# If the previous message is a model message then add a dummy "continue" user message before the function call
|
||||
if prev_role == "model":
|
||||
append_parts(self._oai_content_to_gemini_content({"content": "continue"}), "user")
|
||||
append_parts(parts, role)
|
||||
# In Gemini if the current message is a function response then next message should be a model message.
|
||||
elif role == "function":
|
||||
append_parts(parts, "function")
|
||||
# If the next message is not a model message then add a dummy "continue" model message after the function response
|
||||
if len(messages) > (i + 1) and messages[i + 1]["role"] in ["user", "system"]:
|
||||
append_parts(self._oai_content_to_gemini_content({"content": "continue"}), "model")
|
||||
# If the role is the same as the previous role and both are text messages then concatenate the text
|
||||
elif role == prev_role:
|
||||
append_text_to_last(parts[0].text)
|
||||
# If this is first message or the role is different from the previous role then append the parts
|
||||
else:
|
||||
# If the previous text message is empty then update the text to "empty" as Gemini does not support empty messages
|
||||
if (
|
||||
(len(rst) > 0)
|
||||
and hasattr(rst[-1].parts[0], "_raw_part")
|
||||
and hasattr(rst[-1].parts[0]._raw_part, "text")
|
||||
and (rst[-1].parts[0]._raw_part.text == "")
|
||||
):
|
||||
append_text_to_last("empty")
|
||||
append_parts(parts, role)
|
||||
|
||||
prev_role = role
|
||||
|
||||
# Gemini is strict about the order of roles, such that
|
||||
# 1. The messages should be interleaved between user and model.
|
||||
# 2. The last message must be from the user role.
|
||||
# We add a dummy message "continue" if the last role is not the user.
|
||||
if rst[-1].role != "user":
|
||||
if rst[-1].role != "user" and rst[-1].role != "function":
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user"))
|
||||
rst.append(
|
||||
VertexAIContent(parts=self._oai_content_to_gemini_content({"content": "continue"}), role="user")
|
||||
)
|
||||
else:
|
||||
rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user"))
|
||||
|
||||
rst.append(Content(parts=self._oai_content_to_gemini_content({"content": "continue"}), role="user"))
|
||||
return rst
|
||||
|
||||
def _oai_tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]:
|
||||
"""Convert tools from OAI format to Gemini format."""
|
||||
if len(tools) == 0:
|
||||
return None
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
if self.use_vertexai:
|
||||
function_declaration = VertexAIFunctionDeclaration(
|
||||
name=tool["function"]["name"],
|
||||
description=tool["function"]["description"],
|
||||
parameters=tool["function"]["parameters"],
|
||||
)
|
||||
else:
|
||||
function_declaration = FunctionDeclaration(
|
||||
name=tool["function"]["name"],
|
||||
description=tool["function"]["description"],
|
||||
parameters=self._oai_function_parameters_to_gemini_function_parameters(
|
||||
copy.deepcopy(tool["function"]["parameters"])
|
||||
),
|
||||
)
|
||||
function_declarations.append(function_declaration)
|
||||
if self.use_vertexai:
|
||||
return [VertexAITool(function_declarations=function_declarations)]
|
||||
else:
|
||||
return [Tool(function_declarations=function_declarations)]
|
||||
|
||||
def _oai_function_parameters_to_gemini_function_parameters(
|
||||
self, function_definition: dict[str, any]
|
||||
) -> dict[str, any]:
|
||||
"""
|
||||
Convert OpenAPI function definition parameters to Gemini function parameters definition.
|
||||
The type key is renamed to type_ and the value is capitalized.
|
||||
"""
|
||||
assert "anyOf" not in function_definition, "Union types are not supported for function parameter in Gemini."
|
||||
# Delete the default key as it is not supported in Gemini
|
||||
if "default" in function_definition:
|
||||
del function_definition["default"]
|
||||
|
||||
function_definition["type_"] = function_definition["type"].upper()
|
||||
del function_definition["type"]
|
||||
if "properties" in function_definition:
|
||||
for key in function_definition["properties"]:
|
||||
function_definition["properties"][key] = self._oai_function_parameters_to_gemini_function_parameters(
|
||||
function_definition["properties"][key]
|
||||
)
|
||||
if "items" in function_definition:
|
||||
function_definition["items"] = self._oai_function_parameters_to_gemini_function_parameters(
|
||||
function_definition["items"]
|
||||
)
|
||||
return function_definition
|
||||
|
||||
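# Editor's note (illustrative, not from the PR): given the OpenAI-style schema
#   {"type": "object", "properties": {"expression": {"type": "string", "default": ""}}}
# the conversion above yields
#   {"type_": "OBJECT", "properties": {"expression": {"type_": "STRING"}}}
# i.e. "type" is renamed to "type_", its value is upper-cased, and the
# unsupported "default" key is dropped, recursively through properties and items.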
def _gemini_content_to_oai_choices(self, response: Union[Content, VertexAIContent]) -> List[Choice]:
|
||||
"""Convert response from Gemini format to OAI format."""
|
||||
text = None
|
||||
tool_calls = []
|
||||
for part in response.parts:
|
||||
if part.function_call:
|
||||
if self.use_vertexai:
|
||||
arguments = VertexAIPart.to_dict(part)["function_call"]["args"]
|
||||
else:
|
||||
arguments = Part.to_dict(part)["function_call"]["args"]
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=str(random.randint(0, 1000)),
|
||||
type="function",
|
||||
function=Function(name=part.function_call.name, arguments=json.dumps(arguments)),
|
||||
)
|
||||
)
|
||||
elif part.text:
|
||||
text = part.text
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=text, function_call=None, tool_calls=tool_calls if len(tool_calls) > 0 else None
|
||||
)
|
||||
return [Choice(finish_reason="tool_calls" if tool_calls else "stop", index=0, message=message)]
|
||||
|
||||
@staticmethod
|
||||
def _to_vertexai_safety_settings(safety_settings):
|
||||
"""Convert safety settings to VertexAI format if needed,
|
||||
|
@ -437,6 +629,49 @@ class GeminiClient:
|
|||
else:
|
||||
return safety_settings
|
||||
|
||||
@staticmethod
|
||||
def _to_vertexai_tool_config(tool_config, tools):
|
||||
"""Convert tool config to VertexAI format,
|
||||
like when specifying them in the OAI_CONFIG_LIST
|
||||
"""
|
||||
if (
|
||||
isinstance(tool_config, dict)
|
||||
and (len(tool_config) > 0)
|
||||
and all([isinstance(tool_config[tool_config_entry], dict) for tool_config_entry in tool_config])
|
||||
):
|
||||
if (
|
||||
tool_config["function_calling_config"]["mode"]
|
||||
not in VertexAIToolConfig.FunctionCallingConfig.Mode.__members__
|
||||
):
|
||||
invalid_mode = tool_config["function_calling_config"]
|
||||
logger.error(f"Function calling mode {invalid_mode} is invalid")
|
||||
return None
|
||||
else:
|
||||
# Currently, there is only function calling config
|
||||
func_calling_config_params = {}
|
||||
func_calling_config_params["mode"] = VertexAIToolConfig.FunctionCallingConfig.Mode[
|
||||
tool_config["function_calling_config"]["mode"]
|
||||
]
|
||||
if (
|
||||
(func_calling_config_params["mode"] == VertexAIToolConfig.FunctionCallingConfig.Mode.ANY)
|
||||
and (len(tools) > 0)
|
||||
and all(["function_name" in tool for tool in tools])
|
||||
):
|
||||
# The function names are not yet known when parsing the OAI_CONFIG_LIST
|
||||
func_calling_config_params["allowed_function_names"] = [tool["function_name"] for tool in tools]
|
||||
vertexai_tool_config = VertexAIToolConfig(
|
||||
function_calling_config=VertexAIToolConfig.FunctionCallingConfig(**func_calling_config_params)
|
||||
)
|
||||
return vertexai_tool_config
|
||||
elif isinstance(tool_config, VertexAIToolConfig):
|
||||
return tool_config
|
||||
elif len(tool_config) == 0 and len(tools) == 0:
|
||||
logger.debug("VertexAI tool config is empty!")
|
||||
return None
|
||||
else:
|
||||
logger.error("Invalid VertexAI tool config!")
|
||||
return None
|
||||
|
||||
|
||||
def _to_pil(data: str) -> Image.Image:
|
||||
"""
|
||||
|
@ -470,16 +705,3 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
|
|||
return base64.b64encode(content).decode("utf-8")
|
||||
else:
|
||||
return content
|
||||
|
||||
|
||||
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
|
||||
if "1.5" in model_name or "gemini-experimental" in model_name:
|
||||
# "gemini-1.5-pro-preview-0409"
|
||||
# Cost is $7 per million input tokens and $21 per million output tokens
|
||||
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
|
||||
|
||||
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
|
||||
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
|
||||
|
||||
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
|
||||
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
|
||||
|
|
|
@ -21,6 +21,7 @@ NON_CACHE_KEY = [
|
|||
"azure_ad_token",
|
||||
"azure_ad_token_provider",
|
||||
"credentials",
|
||||
"tool_config",
|
||||
]
|
||||
DEFAULT_AZURE_API_VERSION = "2024-02-01"
|
||||
OAI_PRICE1K = {
File diff suppressed because one or more lines are too long

@@ -1,3 +1,4 @@
import json
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
@ -10,7 +11,9 @@ try:
|
|||
from google.cloud.aiplatform.initializer import global_config as vertexai_global_config
|
||||
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
|
||||
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
|
||||
from vertexai.generative_models import Part as VertexAIPart
|
||||
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
|
||||
from vertexai.generative_models import ToolConfig as VertexAIToolConfig
|
||||
|
||||
from autogen.oai.gemini import GeminiClient
|
||||
|
||||
|
@ -20,6 +23,8 @@ except ImportError:
|
|||
VertexAIHarmBlockThreshold = object
|
||||
VertexAIHarmCategory = object
|
||||
VertexAISafetySetting = object
|
||||
VertexAIPart = object
|
||||
VertexAIToolConfig = object
|
||||
vertexai_global_config = object
|
||||
InternalServerError = object
|
||||
skip = True
|
||||
|
@ -234,8 +239,6 @@ def test_vertexai_safety_setting_list(gemini_client):
|
|||
for category in harm_categories
|
||||
]
|
||||
|
||||
print(safety_settings)
|
||||
|
||||
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
|
||||
|
||||
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
|
||||
|
@ -250,6 +253,59 @@ def test_vertexai_safety_setting_list(gemini_client):
|
|||
assert all(settings_comparison), "Converted safety settings are incorrect"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_vertexai_tool_config(gemini_client):
|
||||
|
||||
tools = [{"function_name": "calculator"}]
|
||||
|
||||
tool_config = {"function_calling_config": {"mode": "ANY"}}
|
||||
|
||||
expected_tool_config = VertexAIToolConfig(
|
||||
function_calling_config=VertexAIToolConfig.FunctionCallingConfig(
|
||||
mode=VertexAIToolConfig.FunctionCallingConfig.Mode.ANY,
|
||||
allowed_function_names=["calculator"],
|
||||
)
|
||||
)
|
||||
|
||||
converted_tool_config = GeminiClient._to_vertexai_tool_config(tool_config, tools)
|
||||
|
||||
converted_mode = converted_tool_config._gapic_tool_config.function_calling_config.mode
|
||||
expected_mode = expected_tool_config._gapic_tool_config.function_calling_config.mode
|
||||
converted_allowed_func = converted_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
|
||||
expected_allowed_func = expected_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
|
||||
|
||||
assert converted_mode == expected_mode, "Function calling mode is not converted correctly"
|
||||
assert (
|
||||
converted_allowed_func == expected_allowed_func
|
||||
), "Function calling allowed function names is not converted correctly"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_vertexai_tool_config_no_functions(gemini_client):
|
||||
|
||||
tools = []
|
||||
|
||||
tool_config = {"function_calling_config": {"mode": "ANY"}}
|
||||
|
||||
expected_tool_config = VertexAIToolConfig(
|
||||
function_calling_config=VertexAIToolConfig.FunctionCallingConfig(
|
||||
mode=VertexAIToolConfig.FunctionCallingConfig.Mode.ANY,
|
||||
)
|
||||
)
|
||||
|
||||
converted_tool_config = GeminiClient._to_vertexai_tool_config(tool_config, tools)
|
||||
|
||||
converted_mode = converted_tool_config._gapic_tool_config.function_calling_config.mode
|
||||
expected_mode = expected_tool_config._gapic_tool_config.function_calling_config.mode
|
||||
converted_allowed_func = converted_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
|
||||
expected_allowed_func = expected_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
|
||||
|
||||
assert converted_mode == expected_mode, "Function calling mode is not converted correctly"
|
||||
assert (
|
||||
converted_allowed_func == expected_allowed_func
|
||||
), "Function calling allowed function names is not converted correctly"
|
||||
|
||||
|
||||
# Test error handling
|
||||
@patch("autogen.oai.gemini.genai")
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
|
@ -279,9 +335,10 @@ def test_cost_calculation(gemini_client, mock_response):
|
|||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.Content")
|
||||
@patch("autogen.oai.gemini.genai.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.genai.configure")
|
||||
def test_create_response(mock_configure, mock_generative_model, gemini_client):
|
||||
def test_create_response(mock_configure, mock_generative_model, mock_content, gemini_client):
|
||||
# Mock the genai model configuration and creation process
|
||||
mock_chat = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
|
@ -292,6 +349,8 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client):
|
|||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = "Example response"
|
||||
mock_history_part.function_call = None
|
||||
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
|
||||
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
|
@ -306,6 +365,55 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client):
|
|||
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.Part")
|
||||
@patch("autogen.oai.gemini.Content")
|
||||
@patch("autogen.oai.gemini.genai.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.genai.configure")
|
||||
def test_create_function_call_response(mock_configure, mock_generative_model, mock_content, mock_part, gemini_client):
|
||||
# Mock the genai model configuration and creation process
|
||||
mock_chat = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_configure.return_value = None
|
||||
mock_generative_model.return_value = mock_model
|
||||
mock_model.start_chat.return_value = mock_chat
|
||||
|
||||
mock_part.to_dict.return_value = {
|
||||
"function_call": {"name": "function_name", "args": {"arg1": "value1", "arg2": "value2"}}
|
||||
}
|
||||
|
||||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = None
|
||||
mock_history_part.function_call.name = "function_name"
|
||||
mock_history_part.function_call.args = {"arg1": "value1", "arg2": "value2"}
|
||||
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
mock_chat.send_message.return_value = MagicMock(
|
||||
history=[
|
||||
MagicMock(
|
||||
parts=[
|
||||
MagicMock(
|
||||
function_call=MagicMock(name="function_name", arguments='{"arg1": "value1", "arg2": "value2"}')
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Call the create method
|
||||
response = gemini_client.create(
|
||||
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
|
||||
)
|
||||
|
||||
# Assertions to check if response is structured as expected
|
||||
assert (
|
||||
response.choices[0].message.tool_calls[0].function.name == "function_name"
|
||||
and json.loads(response.choices[0].message.tool_calls[0].function.arguments)["arg1"] == "value1"
|
||||
), "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.vertexai.init")
|
||||
|
@ -320,7 +428,9 @@ def test_vertexai_create_response(mock_init, mock_generative_model, gemini_clien
|
|||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = "Example response"
|
||||
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
|
||||
mock_history_part.function_call = None
|
||||
mock_history_part.role = "model"
|
||||
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
|
||||
|
@ -330,10 +440,60 @@ def test_vertexai_create_response(mock_init, mock_generative_model, gemini_clien
|
|||
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
|
||||
)
|
||||
|
||||
# Assertions to check if response is structured as expected
|
||||
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.VertexAIPart")
|
||||
@patch("autogen.oai.gemini.VertexAIContent")
|
||||
@patch("autogen.oai.gemini.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.vertexai.init")
|
||||
def test_vertexai_create_function_call_response(
|
||||
mock_init, mock_generative_model, mock_content, mock_part, gemini_client_with_credentials
|
||||
):
|
||||
# Mock the genai model configuration and creation process
|
||||
mock_chat = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_init.return_value = None
|
||||
mock_generative_model.return_value = mock_model
|
||||
mock_model.start_chat.return_value = mock_chat
|
||||
|
||||
mock_part.to_dict.return_value = {
|
||||
"function_call": {"name": "function_name", "args": {"arg1": "value1", "arg2": "value2"}}
|
||||
}
|
||||
|
||||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = None
|
||||
mock_history_part.function_call.name = "function_name"
|
||||
mock_history_part.function_call.args = {"arg1": "value1", "arg2": "value2"}
|
||||
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
mock_chat.send_message.return_value = MagicMock(
|
||||
history=[
|
||||
MagicMock(
|
||||
parts=[
|
||||
MagicMock(
|
||||
function_call=MagicMock(name="function_name", arguments='{"arg1": "value1", "arg2": "value2"}')
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Call the create method
|
||||
response = gemini_client_with_credentials.create(
|
||||
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
|
||||
)
|
||||
|
||||
# Assertions to check if response is structured as expected
|
||||
assert (
|
||||
response.choices[0].message.tool_calls[0].function.name == "function_name"
|
||||
and json.loads(response.choices[0].message.tool_calls[0].function.arguments)["arg1"] == "value1"
|
||||
), "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.vertexai.init")
|
||||
|
@ -348,6 +508,8 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model,
|
|||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = "Example response"
|
||||
mock_history_part.function_call = None
|
||||
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
|
||||
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
|
@ -373,11 +535,11 @@ def test_create_vision_model_response(mock_configure, mock_generative_model, gem
|
|||
|
||||
# Set up a mock to simulate the vision model behavior
|
||||
mock_vision_response = MagicMock()
|
||||
mock_vision_part = MagicMock(text="Vision model output")
|
||||
mock_vision_part = MagicMock(text="Vision model output", function_call=None)
|
||||
|
||||
# Setting up the chain of return values for vision model response
|
||||
mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = (
|
||||
mock_vision_part
|
||||
mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__iter__.return_value = iter(
|
||||
[mock_vision_part]
|
||||
)
|
||||
mock_model.generate_content.return_value = mock_vision_response
|
||||
|
||||
|
@ -420,10 +582,12 @@ def test_vertexai_create_vision_model_response(mock_init, mock_generative_model,
|
|||
|
||||
# Set up a mock to simulate the vision model behavior
|
||||
mock_vision_response = MagicMock()
|
||||
mock_vision_part = MagicMock(text="Vision model output")
|
||||
mock_vision_part = MagicMock(text="Vision model output", function_call=None)
|
||||
|
||||
# Setting up the chain of return values for vision model response
|
||||
mock_vision_response.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = mock_vision_part
|
||||
mock_vision_response.candidates.__getitem__.return_value.content.parts.__iter__.return_value = iter(
|
||||
[mock_vision_part]
|
||||
)
|
||||
|
||||
mock_model.generate_content.return_value = mock_vision_response
|
||||
|
||||
|