Enhance vertexai integration (#3086)

* switch to officially supported Vertex AI message sending + safety setting converion for vertexai

* add system instructions

* switch to officially supported Vertex AI message sending + safety setting converion for vertexai

* fix bug in safety settings conversion

* add missing system instructions

* add safety settings to send message

* add support for credentials objects

* add type checkingchange project_id to project arg

* add more tests

* fix mock creation in test

* extend docstring

* fix errors with gemini message format in chats

* add option for vertexai response validation setting & improve docstring

* readding empty message handling

* add more tests

* extend and improve gemini vertexai jupyter notebook

* rename project arg to project_id and GOOGLE_API_KEY env var to GOOGLE_GEMINI_API_KEY

* adjust docstring formatting
This commit is contained in:
Zoltan Lux 2024-07-23 18:37:48 +02:00 committed by GitHub
parent 1daf852f86
commit a5e5be73b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 518 additions and 104 deletions

View File

@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
return chat_order
def _post_process_carryover_item(carryover_item):
if isinstance(carryover_item, str):
return carryover_item
elif isinstance(carryover_item, dict) and "content" in carryover_item:
return str(carryover_item["content"])
else:
return str(carryover_item)
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
iostream = IOStream.get_default()
@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
UserWarning,
)
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)

View File

@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Ty
from openai import BadRequestError
from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
from .._pydantic import model_dump
@ -2364,7 +2365,7 @@ class ConversableAgent(LLMAgent):
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."

View File

@ -6,7 +6,7 @@ Example:
"config_list": [{
"api_type": "google",
"model": "gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY"),
"api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
"safety_settings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
@ -32,6 +32,7 @@ Resources:
from __future__ import annotations
import base64
import logging
import os
import random
import re
@ -45,13 +46,19 @@ import requests
import vertexai
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
logger = logging.getLogger(__name__)
class GeminiClient:
@ -81,19 +88,26 @@ class GeminiClient:
vertexai_init_args["project"] = params["project_id"]
if "location" in params:
vertexai_init_args["location"] = params["location"]
if "credentials" in params:
assert isinstance(
params["credentials"], Credentials
), "Object type google.auth.credentials.Credentials is expected!"
vertexai_init_args["credentials"] = params["credentials"]
if vertexai_init_args:
vertexai.init(**vertexai_init_args)
def __init__(self, **kwargs):
"""Uses either either api_key for authentication from the LLM config
(specifying the GOOGLE_API_KEY environment variable also works),
(specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
where project_id and location can also be passed as parameters. Service account key file can also be used.
If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
then the default credentials will be used, which could be a personal account if the user is already authenticated in,
like in Google Cloud Shell.
Args:
api_key (str): The API key for using Gemini.
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
google_application_credentials (str): Path to the JSON service account key file of the service account.
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
can also be set instead of using this argument.
@ -103,7 +117,7 @@ class GeminiClient:
"""
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("GOOGLE_API_KEY")
self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
if self.api_key is None:
self.use_vertexai = True
self._initialize_vertexai(**kwargs)
@ -159,12 +173,17 @@ class GeminiClient:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
system_instruction = params.get("system_instruction", None)
response_validation = params.get("response_validation", True)
generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
if self.use_vertexai:
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
else:
safety_settings = params.get("safety_settings", {})
if stream:
@ -181,12 +200,19 @@ class GeminiClient:
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
# we use chat model by default
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
@ -194,7 +220,9 @@ class GeminiClient:
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1], stream=stream)
response = chat.send_message(
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
@ -218,16 +246,22 @@ class GeminiClient:
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
else:
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
# response = chat.send_message(gemini_messages[-1].parts)
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
@ -270,6 +304,8 @@ class GeminiClient:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
if content == "":
content = "empty" # Empty content is not allowed.
if self.use_vertexai:
rst.append(VertexAIPart.from_text(content))
else:
@ -372,6 +408,35 @@ class GeminiClient:
return rst
@staticmethod
def _to_vertexai_safety_settings(safety_settings):
"""Convert safety settings to VertexAI format if needed,
like when specifying them in the OAI_CONFIG_LIST
"""
if isinstance(safety_settings, list) and all(
[
isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
for safety_setting in safety_settings
]
):
vertexai_safety_settings = []
for safety_setting in safety_settings:
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
invalid_category = safety_setting["category"]
logger.error(f"Safety setting category {invalid_category} is invalid")
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
invalid_threshold = safety_setting["threshold"]
logger.error(f"Safety threshold {invalid_threshold} is invalid")
else:
vertexai_safety_setting = VertexAISafetySetting(
category=safety_setting["category"],
threshold=safety_setting["threshold"],
)
vertexai_safety_settings.append(vertexai_safety_setting)
return vertexai_safety_settings
else:
return safety_settings
def _to_pil(data: str) -> Image.Image:
"""

View File

@ -13,7 +13,15 @@ from openai import OpenAI
from openai.types.beta.assistant import Assistant
from packaging.version import parse
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
NON_CACHE_KEY = [
"api_key",
"base_url",
"api_type",
"api_version",
"azure_ad_token",
"azure_ad_token_provider",
"credentials",
]
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
# https://openai.com/api/pricing/

View File

@ -10,6 +10,7 @@ from typing_extensions import Annotated
import autogen
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats
from autogen.agentchat.chat import _post_process_carryover_item
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import reason, skip_openai # noqa: E402
@ -620,6 +621,15 @@ def test_udf_message_in_chats():
print(chat_results[1].summary, chat_results[1].cost)
def test_post_process_carryover_item():
gemini_carryover_item = {"content": "How can I help you?", "role": "model"}
assert (
_post_process_carryover_item(gemini_carryover_item) == gemini_carryover_item["content"]
), "Incorrect carryover postprocessing"
carryover_item = "How can I help you?"
assert _post_process_carryover_item(carryover_item) == carryover_item, "Incorrect carryover postprocessing"
if __name__ == "__main__":
test_chats()
# test_chats_general()
@ -628,3 +638,4 @@ if __name__ == "__main__":
# test_chats_w_func()
# test_chat_messages_for_summary()
# test_udf_message_in_chats()
test_post_process_carryover_item()

View File

@ -1463,6 +1463,58 @@ def test_adding_duplicate_function_warning():
)
def test_process_gemini_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover_content = "How can I help you?"
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=gemini_kwargs)
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"
def test_process_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover = "How can I help you?"
kwargs = {"carryover": carryover}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
carryover_l = ["How can I help you?"]
kwargs = {"carryover": carryover_l}
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
proc_content_empty_carryover = dummy_agent_1._process_carryover(content=content, kwargs={"carryover": None})
assert proc_content_empty_carryover == content, "Incorrect carryover processing"
def test_handle_gemini_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant"
carryover_content = "How can I help you?"
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=gemini_kwargs)
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"
def test_handle_carryover():
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
content = "I am your assistant."
carryover = "How can I help you?"
kwargs = {"carryover": carryover}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
carryover_l = ["How can I help you?"]
kwargs = {"carryover": carryover_l}
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
proc_content_empty_carryover = dummy_agent_1._handle_carryover(message=content, kwargs={"carryover": None})
assert proc_content_empty_carryover == content, "Incorrect carryover processing"
if __name__ == "__main__":
# test_trigger()
# test_context()
@ -1473,6 +1525,10 @@ if __name__ == "__main__":
# test_max_turn()
# test_process_before_send()
# test_message_func()
test_summary()
test_adding_duplicate_function_warning()
# test_function_registration_e2e_sync()
test_process_gemini_carryover()
test_process_carryover()

View File

@ -1,15 +1,26 @@
import os
from unittest.mock import MagicMock, patch
import pytest
try:
import google.auth
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from google.cloud.aiplatform.initializer import global_config as vertexai_global_config
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
from autogen.oai.gemini import GeminiClient
skip = False
except ImportError:
GeminiClient = object
VertexAIHarmBlockThreshold = object
VertexAIHarmCategory = object
VertexAISafetySetting = object
vertexai_global_config = object
InternalServerError = object
skip = True
@ -30,7 +41,24 @@ def mock_response():
@pytest.fixture
def gemini_client():
return GeminiClient(api_key="fake_api_key")
system_message = [
"You are a helpful AI assistant.",
]
return GeminiClient(api_key="fake_api_key", system_message=system_message)
@pytest.fixture
def gemini_google_auth_default_client():
system_message = [
"You are a helpful AI assistant.",
]
return GeminiClient(system_message=system_message)
@pytest.fixture
def gemini_client_with_credentials():
mock_credentials = MagicMock(Credentials)
return GeminiClient(credentials=mock_credentials)
# Test compute location initialization and configuration
@ -42,9 +70,13 @@ def test_compute_location_initialization():
) # Should raise an AssertionError due to specifying API key and compute location
@pytest.fixture
def gemini_google_auth_default_client():
return GeminiClient()
# Test project initialization and configuration
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_project_initialization():
with pytest.raises(AssertionError):
GeminiClient(
api_key="fake_api_key", project_id="fake-project-id"
) # Should raise an AssertionError due to specifying API key and compute location
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@ -52,6 +84,23 @@ def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_google_application_credentials_initialization():
GeminiClient(google_application_credentials="credentials.json", project_id="fake-project-id")
assert (
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] == "credentials.json"
), "Incorrect Google Application Credentials initialization"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_initialization():
mock_credentials = MagicMock(Credentials)
GeminiClient(credentials=mock_credentials, project_id="fake-project-id", location="us-west1")
assert vertexai_global_config.location == "us-west1", "Incorrect VertexAI location initialization"
assert vertexai_global_config.project == "fake-project-id", "Incorrect VertexAI project initialization"
assert vertexai_global_config.credentials == mock_credentials, "Incorrect VertexAI credentials initialization"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(gemini_client):
messages = [
@ -94,6 +143,113 @@ def test_gemini_message_handling(gemini_client):
assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_empty_message_handling(gemini_client):
messages = [
{"role": "system", "content": "You are my personal assistant."},
{"role": "model", "content": "How can I help you?"},
{"role": "user", "content": ""},
{
"role": "model",
"content": "Please provide me with some context or a request! I need more information to assist you.",
},
{"role": "user", "content": ""},
]
converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)
assert converted_messages[-3].parts[0].text == "empty", "Empty message is not converted to 'empty' correctly"
assert converted_messages[-1].parts[0].text == "empty", "Empty message is not converted to 'empty' correctly"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_safety_setting_conversion(gemini_client):
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
]
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
harm_categories = [
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
]
expected_safety_settings = [
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
for category in harm_categories
]
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
for i, expected_setting in enumerate(expected_safety_settings):
converted_setting = converted_safety_settings[i]
yield expected_setting.to_dict() == converted_setting.to_dict()
assert len(converted_safety_settings) == len(
expected_safety_settings
), "The length of the safety settings is incorrect"
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
assert all(settings_comparison), "Converted safety settings are incorrect"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_default_safety_settings_dict(gemini_client):
safety_settings = {
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
}
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
expected_safety_settings = {
category: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH for category in safety_settings.keys()
}
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
for expected_setting_key in expected_safety_settings.keys():
expected_setting = expected_safety_settings[expected_setting_key]
converted_setting = converted_safety_settings[expected_setting_key]
yield expected_setting == converted_setting
assert len(converted_safety_settings) == len(
expected_safety_settings
), "The length of the safety settings is incorrect"
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
assert all(settings_comparison), "Converted safety settings are incorrect"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_vertexai_safety_setting_list(gemini_client):
harm_categories = [
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
]
expected_safety_settings = safety_settings = [
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
for category in harm_categories
]
print(safety_settings)
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
for i, expected_setting in enumerate(expected_safety_settings):
converted_setting = converted_safety_settings[i]
yield expected_setting.to_dict() == converted_setting.to_dict()
assert len(converted_safety_settings) == len(
expected_safety_settings
), "The length of the safety settings is incorrect"
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
assert all(settings_comparison), "Converted safety settings are incorrect"
# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@ -150,6 +306,62 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client):
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
def test_vertexai_create_response(mock_init, mock_generative_model, gemini_client_with_credentials):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_init.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
# Call the create method
response = gemini_client_with_credentials.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
def test_vertexai_default_auth_create_response(mock_init, mock_generative_model, gemini_google_auth_default_client):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_init.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
# Call the create method
response = gemini_google_auth_default_client.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
@ -195,3 +407,49 @@ def test_create_vision_model_response(mock_configure, mock_generative_model, gem
assert (
response.choices[0].message.content == "Vision model output"
), "Response content should match expected output from vision model"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
def test_vertexai_create_vision_model_response(mock_init, mock_generative_model, gemini_google_auth_default_client):
# Mock the genai model configuration and creation process
mock_model = MagicMock()
mock_init.return_value = None
mock_generative_model.return_value = mock_model
# Set up a mock to simulate the vision model behavior
mock_vision_response = MagicMock()
mock_vision_part = MagicMock(text="Vision model output")
# Setting up the chain of return values for vision model response
mock_vision_response.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = mock_vision_part
mock_model.generate_content.return_value = mock_vision_response
# Call the create method with vision model parameters
response = gemini_google_auth_default_client.create(
{
"model": "gemini-pro-vision", # Vision model name
"messages": [
{
"content": [
{"type": "text", "text": "Let's play a game."},
{
"type": "image_url",
"image_url": {
"url": ""
},
},
],
"role": "user",
}
], # Assuming a simple content input for vision
"stream": False,
}
)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.content == "Vision model output"
), "Response content should match expected output from vision model"

View File

@ -64,7 +64,7 @@
" },\n",
" {\n",
" \"model\": \"gemini-1.5-pro\",\n",
" \"project\": \"your-awesome-google-cloud-project-id\",\n",
" \"project_id\": \"your-awesome-google-cloud-project-id\",\n",
" \"location\": \"us-west1\",\n",
" \"google_application_credentials\": \"your-google-service-account-key.json\"\n",
" },\n",

View File

@ -14,9 +14,14 @@
"\n",
"## Requirements\n",
"\n",
"AutoGen requires `Python>=3.8`. To run this notebook example, please install with the [gemini] option:\n",
"Install AutoGen with Gemini features:\n",
"```bash\n",
"pip install \"pyautogen[gemini]\"\n",
"pip install pyautogen[gemini]\n",
"```\n",
"\n",
"### Install other Dependencies of this Notebook\n",
"```bash\n",
"pip install chromadb markdownify pypdf\n",
"```\n",
"\n",
"### Google Cloud Account\n",
@ -66,41 +71,6 @@
" * Please consider restricting the permissions on the key file. For example, you could run `chmod 600 autogen-with-gemini-service-account-key.json` if your keyfile is called autogen-with-gemini-service-account-key.json."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-13T23:40:52.317406Z",
"iopub.status.busy": "2023-02-13T23:40:52.316561Z",
"iopub.status.idle": "2023-02-13T23:40:52.321193Z",
"shell.execute_reply": "2023-02-13T23:40:52.320628Z"
}
},
"outputs": [],
"source": [
"# %pip install \"pyautogen[gemini]\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2023-02-13T23:40:54.634335Z",
"iopub.status.busy": "2023-02-13T23:40:54.633929Z",
"iopub.status.idle": "2023-02-13T23:40:56.105700Z",
"shell.execute_reply": "2023-02-13T23:40:56.105085Z"
},
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"import autogen"
]
},
{
"attachments": {},
"cell_type": "markdown",
@ -109,7 +79,9 @@
"### Configure Authentication\n",
"\n",
"Authentication happens using standard [Google Cloud authentication methods](https://cloud.google.com/docs/authentication), <br/> which means\n",
"that either an already active session can be reused, or by specifying the Google application credentials of a service account.\n",
"that either an already active session can be reused, or by specifying the Google application credentials of a service account. <br/><br/>\n",
"Additionally, AutoGen also supports authentication using `Credentials` objects in Python with the [google-auth library](https://google-auth.readthedocs.io/), which enables even more flexibility.<br/>\n",
"For example, we can even use impersonated credentials.\n",
"\n",
"#### Use Service Account Keyfile\n",
"\n",
@ -121,7 +93,13 @@
"\n",
"If you are using [Cloud Shell](https://shell.cloud.google.com/cloudshell) or [Cloud Shell editor](https://shell.cloud.google.com/cloudshell/editor) in Google Cloud, <br/> then you are already authenticated. If you have the Google Cloud SDK installed locally, <br/> then you can login by running `gcloud auth login` in the command line. \n",
"\n",
"Detailed instructions for installing the Google Cloud SDK can be found [here](https://cloud.google.com/sdk/docs/install)."
"Detailed instructions for installing the Google Cloud SDK can be found [here](https://cloud.google.com/sdk/docs/install).\n",
"\n",
"#### Authentication with the Google Auth Library for Python\n",
"\n",
"The google-auth library supports a wide range of authentication scenarios, and you can simply pass a previously created `Credentials` object to the `llm_config`.<br/>\n",
"The [official documentation](https://google-auth.readthedocs.io/) of the Python package provides a detailed overview of the supported methods and usage examples.<br/>\n",
"If you are already authenticated, like in [Cloud Shell](https://shell.cloud.google.com/cloudshell), or after running the `gcloud auth login` command in a CLI, then the `google.auth.default()` Python method will automatically return your currently active credentials."
]
},
{
@ -147,7 +125,7 @@
" {\n",
" \"model\": \"gemini-1.5-pro\",\n",
" \"api_type\": \"google\",\n",
" \"project\": \"autogen-with-gemini\",\n",
" \"project_id\": \"autogen-with-gemini\",\n",
" \"location\": \"us-west1\",\n",
" \"google_application_credentials\": \"autogen-with-gemini-service-account-key.json\"\n",
" },\n",
@ -172,17 +150,11 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from vertexai.generative_models import (\n",
" GenerationConfig,\n",
" GenerativeModel,\n",
" HarmBlockThreshold,\n",
" HarmCategory,\n",
" Part,\n",
")\n",
"from vertexai.generative_models import HarmBlockThreshold, HarmCategory\n",
"\n",
"safety_settings = {\n",
" HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
@ -194,7 +166,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -205,6 +177,7 @@
"from PIL import Image\n",
"from termcolor import colored\n",
"\n",
"import autogen\n",
"from autogen import Agent, AssistantAgent, ConversableAgent, UserProxyAgent\n",
"from autogen.agentchat.contrib.img_utils import _to_pil, get_image_data\n",
"from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent\n",
@ -215,7 +188,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -242,7 +215,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -253,32 +226,27 @@
"\n",
"\n",
" Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script,\n",
" which returns the value of the definite integral.\n",
" which returns the value of the definite integral\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"Plan:\n",
"1. (Code) Use Python's numerical integration library to compute the integral.\n",
"2. (Language) Output the result.\n",
"1. (code) Use Python's `scipy.integrate.quad` function to compute the integral. \n",
"\n",
"```python\n",
"# filename: integral.py\n",
"import scipy.integrate\n",
"from scipy.integrate import quad\n",
"\n",
"f = lambda x: x**2\n",
"result, error = scipy.integrate.quad(f, 0, 1)\n",
"def f(x):\n",
" return x**2\n",
"\n",
"result, error = quad(f, 0, 1)\n",
"\n",
"print(f\"The definite integral of x^2 from 0 to 1 is: {result}\")\n",
"```\n",
"\n",
"Let me know when you have executed the code. \n",
"Let me know when you have executed this code. \n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
@ -294,13 +262,11 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
"\n",
"The code executed successfully and returned the value of the definite integral as approximately 0.33333333333333337. \n",
"The script executed successfully and returned the definite integral's value as approximately 0.33333333333333337. \n",
"\n",
"This aligns with the analytical solution:\n",
"This aligns with the analytical solution. The indefinite integral of x^2 is (x^3)/3. Evaluating this from 0 to 1 gives us (1^3)/3 - (0^3)/3 = 1/3 = 0.33333...\n",
"\n",
"The integral of x^2 is (x^3)/3. Evaluating this from 0 to 1 gives us (1^3)/3 - (0^3)/3 = 1/3 = 0.33333...\n",
"\n",
"Therefore, the answer is verified to be correct.\n",
"Therefore, the script successfully computed the integral of x^2 from 0 to 1.\n",
"\n",
"TERMINATE\n",
"\n",
@ -325,7 +291,7 @@
" assistant,\n",
" message=\"\"\"\n",
" Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script,\n",
" which returns the value of the definite integral.\"\"\",\n",
" which returns the value of the definite integral\"\"\",\n",
")"
]
},
@ -334,12 +300,52 @@
"metadata": {},
"source": [
"## Example with Gemini Multimodal\n",
"Authentication is the same for vision models as for the text based Gemini models"
"Authentication is the same for vision models as for the text based Gemini models. <br/>\n",
"In this example an object of type `Credentials` will be supplied in order to authenticate.<br/>\n",
"Here, we will use the google application default credentials, so make sure to run the following commands if you are not yet authenticated:\n",
"```bash\n",
"export GOOGLE_APPLICATION_CREDENTIALS=autogen-with-gemini-service-account-key.json\n",
"gcloud auth application-default login\n",
"gcloud config set project autogen-with-gemini\n",
"```\n",
"The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is a path to our service account JSON keyfile, as described in the [Use Service Account Keyfile](#Use Service Account Keyfile) section above.<br/>\n",
"We also need to set the Google cloud project, which is `autogen-with-gemini` in this example.<br/><br/>\n",
"\n",
"Note, we could also run `gcloud auth login` in case we wish to use our personal Google account instead of a service account.\n",
"In this case we need to run the following commands:\n",
"```bash\n",
"gcloud auth login\n",
"gcloud config set project autogen-with-gemini\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import google.auth\n",
"\n",
"scopes = [\"https://www.googleapis.com/auth/cloud-platform\"]\n",
"\n",
"credentials, project_id = google.auth.default(scopes)\n",
"\n",
"gemini_vision_config = [\n",
" {\n",
" \"model\": \"gemini-pro-vision\",\n",
" \"api_type\": \"google\",\n",
" \"project_id\": project_id,\n",
" \"credentials\": credentials,\n",
" \"location\": \"us-west1\",\n",
" \"safety_settings\": safety_settings,\n",
" }\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -356,7 +362,7 @@
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mGemini Vision\u001b[0m (to user_proxy):\n",
"\n",
" The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.\n",
" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -364,17 +370,17 @@
{
"data": {
"text/plain": [
"ChatResult(chat_id=None, chat_history=[{'content': 'Describe what is in this image?\\n<img https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png?raw=true>.', 'role': 'assistant'}, {'content': ' The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.', 'role': 'user'}], summary=' The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.', cost={'usage_including_cached_inference': {'total_cost': 0.0002385, 'gemini-pro-vision': {'cost': 0.0002385, 'prompt_tokens': 267, 'completion_tokens': 70, 'total_tokens': 337}}, 'usage_excluding_cached_inference': {'total_cost': 0.0002385, 'gemini-pro-vision': {'cost': 0.0002385, 'prompt_tokens': 267, 'completion_tokens': 70, 'total_tokens': 337}}}, human_input=[])"
"ChatResult(chat_id=None, chat_history=[{'content': 'Describe what is in this image?\\n<img https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png?raw=true>.', 'role': 'assistant'}, {'content': \" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\", 'role': 'user'}], summary=\" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\", cost={'usage_including_cached_inference': {'total_cost': 0.0001995, 'gemini-pro-vision': {'cost': 0.0001995, 'prompt_tokens': 267, 'completion_tokens': 44, 'total_tokens': 311}}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image_agent = MultimodalConversableAgent(\n",
" \"Gemini Vision\", llm_config={\"config_list\": config_list_gemini_vision, \"seed\": seed}, max_consecutive_auto_reply=1\n",
" \"Gemini Vision\", llm_config={\"config_list\": gemini_vision_config, \"seed\": seed}, max_consecutive_auto_reply=1\n",
")\n",
"\n",
"user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=0)\n",
@ -415,7 +421,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
},
"vscode": {
"interpreter": {