Function Calling Support for Gemini - Part 2 (#3726)

* Added function calling support to GeminiClient

* Appending a continue message from model to alternate between user and model

* Fixed cost calculation to include <128K costing and new 1.5-flash model. Added test case for function_call.

* Add notebook with samples for Gemini

* Updated test case

* Fix to handle non-dict response in function call

* Handle non-dict function results and append a dummy model message between function response and user message

* Fixing message order in Gemini

* Append text as multiple parts instead of text concatenation

* Raising error for Union data types in function parameters

* Delete default key

* Update gemini.py for multiple tool calls + pre-commit formatting

* no function role

* start adding function calling config

* do not serialize tool_config

* improve tool config parsing

* add hint

* improve function calling config

* remove unnecessary comments

* try removing allowed function names in tool config conversion

* fix tool config parsing with empty tools list

* improve logging and case handling with vertexai tool config parsing

* reset file

* check if text is in part

* fix empty part checking case

* fix bug with attribute handling

* skip test if gemini deps are not installed

---------

Co-authored-by: Arjun G <arjun@arjun-g.com>
Co-authored-by: Beibin Li <BeibinLi@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Zoltan Lux 2024-10-10 22:26:05 +02:00 committed by GitHub
parent 3ebd7aeec2
commit 32022b2df6
4 changed files with 1290 additions and 94 deletions


@@ -32,6 +32,8 @@ Resources:
from __future__ import annotations
import base64
import copy
import json
import logging
import os
import random
@@ -39,24 +41,39 @@ import re
import time
import warnings
from io import BytesIO
from typing import Any, Dict, List, Mapping, Union
from typing import Any, Dict, List, Union
import google.generativeai as genai
import requests
import vertexai
from google.ai.generativelanguage import Content, Part
from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import (
Content as VertexAIContent,
)
from vertexai.generative_models import (
FunctionDeclaration as VertexAIFunctionDeclaration,
)
from vertexai.generative_models import (
GenerationConfig as VertexAIGenerationConfig,
)
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
from vertexai.generative_models import (
Tool as VertexAITool,
)
from vertexai.generative_models import (
ToolConfig as VertexAIToolConfig,
)
logger = logging.getLogger(__name__)
@@ -107,7 +124,7 @@ class GeminiClient:
Args:
api_key (str): The API key for using Gemini.
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
google_application_credentials (str): Path to the JSON service account key file of the service account.
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
can also be set instead of using this argument.
@@ -171,6 +188,8 @@ class GeminiClient:
params.get("api_type", "google") # not used
messages = params.get("messages", [])
tools = params.get("tools", [])
tool_config = params.get("tool_config", {})
stream = params.get("stream", False)
n_response = params.get("n", 1)
system_instruction = params.get("system_instruction", None)
@@ -183,6 +202,7 @@ class GeminiClient:
}
if self.use_vertexai:
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
tool_config = GeminiClient._to_vertexai_tool_config(tool_config, tools)
else:
safety_settings = params.get("safety_settings", {})
@@ -198,12 +218,15 @@ class GeminiClient:
if "vision" not in model_name:
# A. create and call the chat model.
gemini_messages = self._oai_messages_to_gemini_messages(messages)
gemini_tools = self._oai_tools_to_gemini_tools(tools)
if self.use_vertexai:
model = GenerativeModel(
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
tools=gemini_tools,
tool_config=tool_config,
)
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
@@ -213,12 +236,13 @@ class GeminiClient:
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
tools=gemini_tools,
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
ans = None
ans: Union[Content, VertexAIContent] = None
try:
response = chat.send_message(
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
@@ -234,7 +258,7 @@ class GeminiClient:
raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
else:
# `ans = response.text` is unstable. Use the following code instead.
ans: str = chat.history[-1].parts[0].text
ans: Union[Content, VertexAIContent] = chat.history[-1]
break
if ans is None:
@@ -262,7 +286,7 @@ class GeminiClient:
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1].parts)
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
user_message = self._oai_content_to_gemini_content(messages[-1])
if len(messages) > 2:
warnings.warn(
"Warning: Gemini's vision model does not support chat history yet.",
@@ -273,16 +297,14 @@ class GeminiClient:
response = model.generate_content(user_message, stream=stream)
# ans = response.text
if self.use_vertexai:
ans: str = response.candidates[0].content.parts[0].text
ans: VertexAIContent = response.candidates[0].content
else:
ans: str = response._result.candidates[0].content.parts[0].text
ans: Content = response._result.candidates[0].content
prompt_tokens = model.count_tokens(user_message).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
completion_tokens = model.count_tokens(ans.parts[0].text).total_tokens
# 3. convert output
message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
choices = [Choice(finish_reason="stop", index=0, message=message)]
choices = self._gemini_content_to_oai_choices(ans)
response_oai = ChatCompletion(
id=str(random.randint(0, 1000)),
@@ -295,31 +317,87 @@ class GeminiClient:
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
cost=self._calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
)
return response_oai
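For orientation, a minimal sketch of how a caller exercises the new function-calling path end to end (hypothetical usage, not part of the diff; assumes a configured GeminiClient instance gemini_client and a tools list like the calculator example shown after _oai_tools_to_gemini_tools below):

response = gemini_client.create(
    {
        "model": "gemini-pro",
        "messages": [{"role": "user", "content": "What is 2 + 3?"}],
        "tools": tools,  # OAI-style tool definitions (hypothetical; see example below)
        "stream": False,
    }
)
# Any Gemini function calls come back as OAI tool calls:
# response.choices[0].message.tool_calls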
def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List:
# If str is not a json string return str as is
def _to_json(self, str) -> dict:
try:
return json.loads(str)
except ValueError:
return str
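A quick illustration of the fallback behavior of _to_json (assuming a GeminiClient instance client; illustrative only):

# client._to_json('{"result": 5}')  ->  {"result": 5}   (valid JSON is parsed)
# client._to_json("five")           ->  "five"          (ValueError falls back to the raw string)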
def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
if content == "":
content = "empty" # Empty content is not allowed.
if isinstance(message["content"], str):
if message["content"] == "":
message["content"] = "empty" # Empty content is not allowed.
if self.use_vertexai:
rst.append(VertexAIPart.from_text(content))
rst.append(VertexAIPart.from_text(message["content"]))
else:
rst.append(Part(text=content))
rst.append(Part(text=message["content"]))
return rst
assert isinstance(content, list)
if "tool_calls" in message:
if self.use_vertexai:
for tool_call in message["tool_calls"]:
rst.append(
VertexAIPart.from_dict(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(tool_call["function"]["arguments"]),
}
}
)
)
else:
for tool_call in message["tool_calls"]:
rst.append(
Part(
function_call=FunctionCall(
name=tool_call["function"]["name"],
args=json.loads(tool_call["function"]["arguments"]),
)
)
)
return rst
for msg in content:
if message["role"] == "tool":
if self.use_vertexai:
rst.append(
VertexAIPart.from_function_response(
name=message["name"], response={"result": self._to_json(message["content"])}
)
)
else:
rst.append(
Part(
function_response=FunctionResponse(
name=message["name"], response={"result": self._to_json(message["content"])}
)
)
)
return rst
if isinstance(message["content"], str):
if self.use_vertexai:
rst.append(VertexAIPart.from_text(message["content"]))
else:
rst.append(Part(text=message["content"]))
return rst
assert isinstance(message["content"], list)
for msg in message["content"]:
if isinstance(msg, dict):
assert "type" in msg, f"Missing 'type' field in message: {msg}"
if msg["type"] == "text":
if self.use_vertexai:
rst.append(VertexAIPart.from_text(text=msg["text"]))
rst.append(VertexAIPart.from_text(msg["text"]))
else:
rst.append(Part(text=msg["text"]))
elif msg["type"] == "image_url":
@@ -340,34 +418,32 @@ class GeminiClient:
raise ValueError(f"Unsupported message type: {type(msg)}")
return rst
def _concat_parts(self, parts: List[Part]) -> List:
"""Concatenate parts with the same type.
If two adjacent parts both have the "text" attribute, then it will be joined into one part.
"""
if not parts:
return []
def _calculate_gemini_cost(self, input_tokens: int, output_tokens: int, model_name: str) -> float:
if "1.5-pro" in model_name:
if (input_tokens + output_tokens) <= 128000:
# "gemini-1.5-pro"
# When total tokens is less than 128K cost is $3.5 per million input tokens and $10.5 per million output tokens
return 3.5 * input_tokens / 1e6 + 10.5 * output_tokens / 1e6
# "gemini-1.5-pro"
# Cost is $7 per million input tokens and $21 per million output tokens
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
concatenated_parts = []
previous_part = parts[0]
if "1.5-flash" in model_name:
if (input_tokens + output_tokens) <= 128000:
# "gemini-1.5-flash"
# Cost is $0.35 per million input tokens and $1.05 per million output tokens
return 0.35 * input_tokens / 1e6 + 1.05 * output_tokens / 1e6
# "gemini-1.5-flash"
# When total tokens exceed 128K, cost is $0.70 per million input tokens and $2.10 per million output tokens
return 0.70 * input_tokens / 1e6 + 2.10 * output_tokens / 1e6
for current_part in parts[1:]:
if previous_part.text != "":
if self.use_vertexai:
previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
else:
previous_part.text += current_part.text
else:
concatenated_parts.append(previous_part)
previous_part = current_part
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(
f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning
)
if previous_part.text == "":
if self.use_vertexai:
previous_part = VertexAIPart.from_text("empty")
else:
previous_part.text = "empty" # Empty content is not allowed.
concatenated_parts.append(previous_part)
return concatenated_parts
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
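A worked example of the tiered rates above (arithmetic only, not part of the diff):

# gemini-1.5-pro, 100,000 input + 50,000 output tokens (150K total, beyond the 128K tier):
#   7.0 * 100_000 / 1e6 + 21.0 * 50_000 / 1e6 = 0.70 + 1.05 = $1.75
# gemini-1.5-flash, 50,000 input + 10,000 output tokens (60K total, within the 128K tier):
#   0.35 * 50_000 / 1e6 + 1.05 * 10_000 / 1e6 = 0.0175 + 0.0105 = $0.028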
def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Gemini format.
@@ -376,38 +452,154 @@ class GeminiClient:
"""
prev_role = None
rst = []
curr_parts = []
for i, message in enumerate(messages):
parts = self._oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"
if (prev_role is None) or (role == prev_role):
curr_parts += parts
elif role != prev_role:
if self.use_vertexai:
rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
else:
rst.append(Content(parts=curr_parts, role=prev_role))
curr_parts = parts
prev_role = role
# handle the last message
if self.use_vertexai:
rst.append(VertexAIContent(parts=curr_parts, role=role))
else:
rst.append(Content(parts=curr_parts, role=role))
def append_parts(parts, role):
if self.use_vertexai:
rst.append(VertexAIContent(parts=parts, role=role))
else:
rst.append(Content(parts=parts, role=role))
def append_text_to_last(text):
if self.use_vertexai:
rst[-1] = VertexAIContent(parts=[*rst[-1].parts, VertexAIPart.from_text(text)], role=rst[-1].role)
else:
rst[-1] = Content(parts=[*rst[-1].parts, Part(text=text)], role=rst[-1].role)
def is_function_call(parts):
return self.use_vertexai and parts[0].function_call or not self.use_vertexai and "function_call" in parts[0]
for i, message in enumerate(messages):
# Since the tool response message does not have the "name" field, we look it up in the matching tool_calls entry.
if message["role"] == "tool":
message["name"] = [
m["tool_calls"][i]["function"]["name"]
for m in messages
if "tool_calls" in m
for i, tc in enumerate(m["tool_calls"])
if tc["id"] == message["tool_call_id"]
][0]
parts = self._oai_content_to_gemini_content(message)
role = "user" if message["role"] in ["user", "system"] else "model"
# In Gemini if the current message is a function call then previous message should not be a model message.
if is_function_call(parts):
# If the previous message is a model message then add a dummy "continue" user message before the function call
if prev_role == "model":
append_parts(self._oai_content_to_gemini_content({"content": "continue"}), "user")
append_parts(parts, role)
# In Gemini if the current message is a function response then next message should be a model message.
elif role == "function":
append_parts(parts, "function")
# If the next message is not a model message then add a dummy "continue" model message after the function response
if len(messages) > (i + 1) and messages[i + 1]["role"] in ["user", "system"]:
append_parts(self._oai_content_to_gemini_content({"content": "continue"}), "model")
# If the role is the same as the previous role and both are text messages then concatenate the text
elif role == prev_role:
append_text_to_last(parts[0].text)
# If this is first message or the role is different from the previous role then append the parts
else:
# If the previous text message is empty then update the text to "empty" as Gemini does not support empty messages
if (
(len(rst) > 0)
and hasattr(rst[-1].parts[0], "_raw_part")
and hasattr(rst[-1].parts[0]._raw_part, "text")
and (rst[-1].parts[0]._raw_part.text == "")
):
append_text_to_last("empty")
append_parts(parts, role)
prev_role = role
# Gemini is strict about the order of roles, such that
# 1. The messages should be interleaved between user and model.
# 2. The last message must be from the user role.
# We add a dummy message "continue" if the last role is not the user.
if rst[-1].role != "user":
if rst[-1].role != "user" and rst[-1].role != "function":
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user"))
rst.append(
VertexAIContent(parts=self._oai_content_to_gemini_content({"content": "continue"}), role="user")
)
else:
rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user"))
rst.append(Content(parts=self._oai_content_to_gemini_content({"content": "continue"}), role="user"))
return rst
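To recap the ordering constraints this conversion enforces (a summary sketch, not part of the diff):

# 1. A model function_call that would directly follow another model turn gets a
#    dummy user "continue" message inserted before it.
# 2. A function response followed by a user/system message gets a dummy model
#    "continue" message inserted after it.
# 3. Consecutive same-role text messages are merged into a single Content.
# 4. If the final turn is not from the user (or a function response), a user
#    "continue" message is appended so the history ends on a user turn.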
def _oai_tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]:
"""Convert tools from OAI format to Gemini format."""
if len(tools) == 0:
return None
function_declarations = []
for tool in tools:
if self.use_vertexai:
function_declaration = VertexAIFunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"]["description"],
parameters=tool["function"]["parameters"],
)
else:
function_declaration = FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"]["description"],
parameters=self._oai_function_parameters_to_gemini_function_parameters(
copy.deepcopy(tool["function"]["parameters"])
),
)
function_declarations.append(function_declaration)
if self.use_vertexai:
return [VertexAITool(function_declarations=function_declarations)]
else:
return [Tool(function_declarations=function_declarations)]
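For reference, the OAI-style input this conversion expects looks like the following (a hypothetical calculator tool, also used in the sketches below):

tools = [
    {
        "type": "function",
        "function": {
            "name": "calculator",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number", "description": "First addend"},
                    "b": {"type": "number", "description": "Second addend"},
                },
                "required": ["a", "b"],
            },
        },
    }
]
# Yields a single Tool (or VertexAITool) wrapping one FunctionDeclaration named "calculator".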
def _oai_function_parameters_to_gemini_function_parameters(
self, function_definition: dict[str, any]
) -> dict[str, any]:
"""
Convert OpenAPI function definition parameters to Gemini function parameter definitions.
The type key is renamed to type_ and the value is capitalized.
"""
assert "anyOf" not in function_definition, "Union types are not supported for function parameter in Gemini."
# Delete the default key as it is not supported in Gemini
if "default" in function_definition:
del function_definition["default"]
function_definition["type_"] = function_definition["type"].upper()
del function_definition["type"]
if "properties" in function_definition:
for key in function_definition["properties"]:
function_definition["properties"][key] = self._oai_function_parameters_to_gemini_function_parameters(
function_definition["properties"][key]
)
if "items" in function_definition:
function_definition["items"] = self._oai_function_parameters_to_gemini_function_parameters(
function_definition["items"]
)
return function_definition
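A before/after sketch of this parameter mapping (using the calculator schema above):

# OpenAPI-style input:
#   {"type": "object", "properties": {"a": {"type": "number", "default": 0}}}
# Gemini-style output ("type" renamed to "type_", value upper-cased, "default"
# dropped, applied recursively through "properties" and "items"):
#   {"type_": "OBJECT", "properties": {"a": {"type_": "NUMBER"}}}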
def _gemini_content_to_oai_choices(self, response: Union[Content, VertexAIContent]) -> List[Choice]:
"""Convert response from Gemini format to OAI format."""
text = None
tool_calls = []
for part in response.parts:
if part.function_call:
if self.use_vertexai:
arguments = VertexAIPart.to_dict(part)["function_call"]["args"]
else:
arguments = Part.to_dict(part)["function_call"]["args"]
tool_calls.append(
ChatCompletionMessageToolCall(
id=str(random.randint(0, 1000)),
type="function",
function=Function(name=part.function_call.name, arguments=json.dumps(arguments)),
)
)
elif part.text:
text = part.text
message = ChatCompletionMessage(
role="assistant", content=text, function_call=None, tool_calls=tool_calls if len(tool_calls) > 0 else None
)
return [Choice(finish_reason="tool_calls" if tool_calls else "stop", index=0, message=message)]
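Concretely, the mapping back to OAI looks like this (a sketch; the tool call id is a random integer, as above):

# Gemini part:  function_call { name: "calculator", args: {"a": 2, "b": 3} }
# OAI choice:   message.tool_calls == [ChatCompletionMessageToolCall(
#                   function=Function(name="calculator", arguments='{"a": 2, "b": 3}'))]
#               with finish_reason == "tool_calls".
# A plain text part instead becomes message.content with finish_reason == "stop".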
@staticmethod
def _to_vertexai_safety_settings(safety_settings):
"""Convert safety settings to VertexAI format if needed,
@@ -437,6 +629,49 @@ class GeminiClient:
else:
return safety_settings
@staticmethod
def _to_vertexai_tool_config(tool_config, tools):
"""Convert tool config to VertexAI format,
like when specifying it in the OAI_CONFIG_LIST
"""
if (
isinstance(tool_config, dict)
and (len(tool_config) > 0)
and all([isinstance(tool_config[tool_config_entry], dict) for tool_config_entry in tool_config])
):
if (
tool_config["function_calling_config"]["mode"]
not in VertexAIToolConfig.FunctionCallingConfig.Mode.__members__
):
invalid_mode = tool_config["function_calling_config"]
logger.error(f"Function calling mode {invalid_mode} is invalid")
return None
else:
# Currently, there is only function calling config
func_calling_config_params = {}
func_calling_config_params["mode"] = VertexAIToolConfig.FunctionCallingConfig.Mode[
tool_config["function_calling_config"]["mode"]
]
if (
(func_calling_config_params["mode"] == VertexAIToolConfig.FunctionCallingConfig.Mode.ANY)
and (len(tools) > 0)
and all(["function_name" in tool for tool in tools])
):
# The function names are not yet known when parsing the OAI_CONFIG_LIST
func_calling_config_params["allowed_function_names"] = [tool["function_name"] for tool in tools]
vertexai_tool_config = VertexAIToolConfig(
function_calling_config=VertexAIToolConfig.FunctionCallingConfig(**func_calling_config_params)
)
return vertexai_tool_config
elif isinstance(tool_config, VertexAIToolConfig):
return tool_config
elif len(tool_config) == 0 and len(tools) == 0:
logger.debug("VertexAI tool config is empty!")
return None
else:
logger.error("Invalid VertexAI tool config!")
return None
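Example of the accepted dict form (mirroring test_vertexai_tool_config further down):

tool_config = {"function_calling_config": {"mode": "ANY"}}
tools = [{"function_name": "calculator"}]
# _to_vertexai_tool_config(tool_config, tools) returns a VertexAIToolConfig with
# mode ANY and allowed_function_names=["calculator"]; an empty config with no
# tools returns None, and anything else is logged as invalid and returns None.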
def _to_pil(data: str) -> Image.Image:
"""
@@ -470,16 +705,3 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
return base64.b64encode(content).decode("utf-8")
else:
return content
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
if "1.5" in model_name or "gemini-experimental" in model_name:
# "gemini-1.5-pro-preview-0409"
# Cost is $7 per million input tokens and $21 per million output tokens
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6


@@ -21,6 +21,7 @@ NON_CACHE_KEY = [
"azure_ad_token",
"azure_ad_token_provider",
"credentials",
"tool_config",
]
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {

File diff suppressed because one or more lines are too long


@@ -1,3 +1,4 @@
import json
import os
from unittest.mock import MagicMock, patch
@@ -10,7 +11,9 @@ try:
from google.cloud.aiplatform.initializer import global_config as vertexai_global_config
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
from vertexai.generative_models import ToolConfig as VertexAIToolConfig
from autogen.oai.gemini import GeminiClient
@@ -20,6 +23,8 @@ except ImportError:
VertexAIHarmBlockThreshold = object
VertexAIHarmCategory = object
VertexAISafetySetting = object
VertexAIPart = object
VertexAIToolConfig = object
vertexai_global_config = object
InternalServerError = object
skip = True
@@ -234,8 +239,6 @@ def test_vertexai_safety_setting_list(gemini_client):
for category in harm_categories
]
print(safety_settings)
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
@@ -250,6 +253,59 @@ def test_vertexai_safety_setting_list(gemini_client):
assert all(settings_comparison), "Converted safety settings are incorrect"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_tool_config(gemini_client):
tools = [{"function_name": "calculator"}]
tool_config = {"function_calling_config": {"mode": "ANY"}}
expected_tool_config = VertexAIToolConfig(
function_calling_config=VertexAIToolConfig.FunctionCallingConfig(
mode=VertexAIToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["calculator"],
)
)
converted_tool_config = GeminiClient._to_vertexai_tool_config(tool_config, tools)
converted_mode = converted_tool_config._gapic_tool_config.function_calling_config.mode
expected_mode = expected_tool_config._gapic_tool_config.function_calling_config.mode
converted_allowed_func = converted_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
expected_allowed_func = expected_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
assert converted_mode == expected_mode, "Function calling mode is not converted correctly"
assert (
converted_allowed_func == expected_allowed_func
), "Function calling allowed function names is not converted correctly"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_tool_config_no_functions(gemini_client):
tools = []
tool_config = {"function_calling_config": {"mode": "ANY"}}
expected_tool_config = VertexAIToolConfig(
function_calling_config=VertexAIToolConfig.FunctionCallingConfig(
mode=VertexAIToolConfig.FunctionCallingConfig.Mode.ANY,
)
)
converted_tool_config = GeminiClient._to_vertexai_tool_config(tool_config, tools)
converted_mode = converted_tool_config._gapic_tool_config.function_calling_config.mode
expected_mode = expected_tool_config._gapic_tool_config.function_calling_config.mode
converted_allowed_func = converted_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
expected_allowed_func = expected_tool_config._gapic_tool_config.function_calling_config.allowed_function_names
assert converted_mode == expected_mode, "Function calling mode is not converted correctly"
assert (
converted_allowed_func == expected_allowed_func
), "Function calling allowed function names is not converted correctly"
# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@@ -279,9 +335,10 @@ def test_cost_calculation(gemini_client, mock_response):
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.Content")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
def test_create_response(mock_configure, mock_generative_model, gemini_client):
def test_create_response(mock_configure, mock_generative_model, mock_content, gemini_client):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
@@ -292,6 +349,8 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client):
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_history_part.function_call = None
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
# Setup the mock to return a mocked chat response
@@ -306,6 +365,55 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client):
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.Part")
@patch("autogen.oai.gemini.Content")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
def test_create_function_call_response(mock_configure, mock_generative_model, mock_content, mock_part, gemini_client):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_configure.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat
mock_part.to_dict.return_value = {
"function_call": {"name": "function_name", "args": {"arg1": "value1", "arg2": "value2"}}
}
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = None
mock_history_part.function_call.name = "function_name"
mock_history_part.function_call.args = {"arg1": "value1", "arg2": "value2"}
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(
history=[
MagicMock(
parts=[
MagicMock(
function_call=MagicMock(name="function_name", arguments='{"arg1": "value1", "arg2": "value2"}')
)
]
)
]
)
# Call the create method
response = gemini_client.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.tool_calls[0].function.name == "function_name"
and json.loads(response.choices[0].message.tool_calls[0].function.arguments)["arg1"] == "value1"
), "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
@@ -320,7 +428,9 @@ def test_vertexai_create_response(mock_init, mock_generative_model, gemini_clien
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
mock_history_part.function_call = None
mock_history_part.role = "model"
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
@@ -330,10 +440,60 @@ def test_vertexai_create_response(mock_init, mock_generative_model, gemini_clien
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.VertexAIPart")
@patch("autogen.oai.gemini.VertexAIContent")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
def test_vertexai_create_function_call_response(
mock_init, mock_generative_model, mock_content, mock_part, gemini_client_with_credentials
):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_init.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat
mock_part.to_dict.return_value = {
"function_call": {"name": "function_name", "args": {"arg1": "value1", "arg2": "value2"}}
}
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = None
mock_history_part.function_call.name = "function_name"
mock_history_part.function_call.args = {"arg1": "value1", "arg2": "value2"}
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(
history=[
MagicMock(
parts=[
MagicMock(
function_call=MagicMock(name="function_name", arguments='{"arg1": "value1", "arg2": "value2"}')
)
]
)
]
)
# Call the create method
response = gemini_client_with_credentials.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.tool_calls[0].function.name == "function_name"
and json.loads(response.choices[0].message.tool_calls[0].function.arguments)["arg1"] == "value1"
), "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
@@ -348,6 +508,8 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model,
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_history_part.function_call = None
mock_chat.history.__getitem__.return_value.parts.__iter__.return_value = iter([mock_history_part])
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
# Setup the mock to return a mocked chat response
@@ -373,11 +535,11 @@ def test_create_vision_model_response(mock_configure, mock_generative_model, gem
# Set up a mock to simulate the vision model behavior
mock_vision_response = MagicMock()
mock_vision_part = MagicMock(text="Vision model output")
mock_vision_part = MagicMock(text="Vision model output", function_call=None)
# Setting up the chain of return values for vision model response
mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = (
mock_vision_part
mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__iter__.return_value = iter(
[mock_vision_part]
)
mock_model.generate_content.return_value = mock_vision_response
@@ -420,10 +582,12 @@ def test_vertexai_create_vision_model_response(mock_init, mock_generative_model,
# Set up a mock to simulate the vision model behavior
mock_vision_response = MagicMock()
mock_vision_part = MagicMock(text="Vision model output")
mock_vision_part = MagicMock(text="Vision model output", function_call=None)
# Setting up the chain of return values for vision model response
mock_vision_response.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = mock_vision_part
mock_vision_response.candidates.__getitem__.return_value.content.parts.__iter__.return_value = iter(
[mock_vision_part]
)
mock_model.generate_content.return_value = mock_vision_response