mirror of https://github.com/microsoft/autogen.git
Enhance vertexai integration (#3086)
* switch to officially supported Vertex AI message sending + safety setting converion for vertexai * add system instructions * switch to officially supported Vertex AI message sending + safety setting converion for vertexai * fix bug in safety settings conversion * add missing system instructions * add safety settings to send message * add support for credentials objects * add type checkingchange project_id to project arg * add more tests * fix mock creation in test * extend docstring * fix errors with gemini message format in chats * add option for vertexai response validation setting & improve docstring * readding empty message handling * add more tests * extend and improve gemini vertexai jupyter notebook * rename project arg to project_id and GOOGLE_API_KEY env var to GOOGLE_GEMINI_API_KEY * adjust docstring formatting
This commit is contained in:
parent
1daf852f86
commit
a5e5be73b5
|
@ -107,6 +107,15 @@ def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite
|
|||
return chat_order
|
||||
|
||||
|
||||
def _post_process_carryover_item(carryover_item):
|
||||
if isinstance(carryover_item, str):
|
||||
return carryover_item
|
||||
elif isinstance(carryover_item, dict) and "content" in carryover_item:
|
||||
return str(carryover_item["content"])
|
||||
else:
|
||||
return str(carryover_item)
|
||||
|
||||
|
||||
def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
|
||||
iostream = IOStream.get_default()
|
||||
|
||||
|
@ -116,7 +125,7 @@ def __post_carryover_processing(chat_info: Dict[str, Any]) -> None:
|
|||
UserWarning,
|
||||
)
|
||||
print_carryover = (
|
||||
("\n").join([t for t in chat_info["carryover"]])
|
||||
("\n").join([_post_process_carryover_item(t) for t in chat_info["carryover"]])
|
||||
if isinstance(chat_info["carryover"], list)
|
||||
else chat_info["carryover"]
|
||||
)
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Ty
|
|||
|
||||
from openai import BadRequestError
|
||||
|
||||
from autogen.agentchat.chat import _post_process_carryover_item
|
||||
from autogen.exception_utils import InvalidCarryOverType, SenderRequired
|
||||
|
||||
from .._pydantic import model_dump
|
||||
|
@ -2364,7 +2365,7 @@ class ConversableAgent(LLMAgent):
|
|||
if isinstance(kwargs["carryover"], str):
|
||||
content += "\nContext: \n" + kwargs["carryover"]
|
||||
elif isinstance(kwargs["carryover"], list):
|
||||
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
|
||||
content += "\nContext: \n" + ("\n").join([_post_process_carryover_item(t) for t in kwargs["carryover"]])
|
||||
else:
|
||||
raise InvalidCarryOverType(
|
||||
"Carryover should be a string or a list of strings. Not adding carryover to the message."
|
||||
|
|
|
@ -6,7 +6,7 @@ Example:
|
|||
"config_list": [{
|
||||
"api_type": "google",
|
||||
"model": "gemini-pro",
|
||||
"api_key": os.environ.get("GOOGLE_API_KEY"),
|
||||
"api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
|
||||
"safety_settings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
|
||||
|
@ -32,6 +32,7 @@ Resources:
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
@ -45,13 +46,19 @@ import requests
|
|||
import vertexai
|
||||
from google.ai.generativelanguage import Content, Part
|
||||
from google.api_core.exceptions import InternalServerError
|
||||
from google.auth.credentials import Credentials
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from PIL import Image
|
||||
from vertexai.generative_models import Content as VertexAIContent
|
||||
from vertexai.generative_models import GenerativeModel
|
||||
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
|
||||
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
|
||||
from vertexai.generative_models import Part as VertexAIPart
|
||||
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
|
@ -81,29 +88,36 @@ class GeminiClient:
|
|||
vertexai_init_args["project"] = params["project_id"]
|
||||
if "location" in params:
|
||||
vertexai_init_args["location"] = params["location"]
|
||||
if "credentials" in params:
|
||||
assert isinstance(
|
||||
params["credentials"], Credentials
|
||||
), "Object type google.auth.credentials.Credentials is expected!"
|
||||
vertexai_init_args["credentials"] = params["credentials"]
|
||||
if vertexai_init_args:
|
||||
vertexai.init(**vertexai_init_args)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Uses either either api_key for authentication from the LLM config
|
||||
(specifying the GOOGLE_API_KEY environment variable also works),
|
||||
(specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
|
||||
or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
|
||||
where project_id and location can also be passed as parameters. Service account key file can also be used.
|
||||
If neither a service account key file, nor the api_key are passed, then the default credentials will be used,
|
||||
which could be a personal account if the user is already authenticated in, like in Google Cloud Shell.
|
||||
where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
|
||||
or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
|
||||
then the default credentials will be used, which could be a personal account if the user is already authenticated in,
|
||||
like in Google Cloud Shell.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for using Gemini.
|
||||
credentials (google.auth.credentials.Credentials): credentials to be used for authentication with vertexai.
|
||||
google_application_credentials (str): Path to the JSON service account key file of the service account.
|
||||
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
||||
can also be set instead of using this argument.
|
||||
Alternatively, the GOOGLE_APPLICATION_CREDENTIALS environment variable
|
||||
can also be set instead of using this argument.
|
||||
project_id (str): Google Cloud project id, which is only valid in case no API key is specified.
|
||||
location (str): Compute region to be used, like 'us-west1'.
|
||||
This parameter is only valid in case no API key is specified.
|
||||
This parameter is only valid in case no API key is specified.
|
||||
"""
|
||||
self.api_key = kwargs.get("api_key", None)
|
||||
if not self.api_key:
|
||||
self.api_key = os.getenv("GOOGLE_API_KEY")
|
||||
self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
|
||||
if self.api_key is None:
|
||||
self.use_vertexai = True
|
||||
self._initialize_vertexai(**kwargs)
|
||||
|
@ -159,13 +173,18 @@ class GeminiClient:
|
|||
messages = params.get("messages", [])
|
||||
stream = params.get("stream", False)
|
||||
n_response = params.get("n", 1)
|
||||
system_instruction = params.get("system_instruction", None)
|
||||
response_validation = params.get("response_validation", True)
|
||||
|
||||
generation_config = {
|
||||
gemini_term: params[autogen_term]
|
||||
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
|
||||
if autogen_term in params
|
||||
}
|
||||
safety_settings = params.get("safety_settings", {})
|
||||
if self.use_vertexai:
|
||||
safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", {}))
|
||||
else:
|
||||
safety_settings = params.get("safety_settings", {})
|
||||
|
||||
if stream:
|
||||
warnings.warn(
|
||||
|
@ -181,20 +200,29 @@ class GeminiClient:
|
|||
gemini_messages = self._oai_messages_to_gemini_messages(messages)
|
||||
if self.use_vertexai:
|
||||
model = GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
system_instruction=system_instruction,
|
||||
)
|
||||
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
|
||||
else:
|
||||
# we use chat model by default
|
||||
model = genai.GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
system_instruction=system_instruction,
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
chat = model.start_chat(history=gemini_messages[:-1])
|
||||
chat = model.start_chat(history=gemini_messages[:-1])
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
ans = None
|
||||
try:
|
||||
response = chat.send_message(gemini_messages[-1], stream=stream)
|
||||
response = chat.send_message(
|
||||
gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings
|
||||
)
|
||||
except InternalServerError:
|
||||
delay = 5 * (2**attempt)
|
||||
warnings.warn(
|
||||
|
@ -218,16 +246,22 @@ class GeminiClient:
|
|||
# B. handle the vision model
|
||||
if self.use_vertexai:
|
||||
model = GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
system_instruction=system_instruction,
|
||||
)
|
||||
else:
|
||||
model = genai.GenerativeModel(
|
||||
model_name, generation_config=generation_config, safety_settings=safety_settings
|
||||
model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings,
|
||||
system_instruction=system_instruction,
|
||||
)
|
||||
genai.configure(api_key=self.api_key)
|
||||
# Gemini's vision model does not support chat history yet
|
||||
# chat = model.start_chat(history=gemini_messages[:-1])
|
||||
# response = chat.send_message(gemini_messages[-1])
|
||||
# response = chat.send_message(gemini_messages[-1].parts)
|
||||
user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
|
||||
if len(messages) > 2:
|
||||
warnings.warn(
|
||||
|
@ -270,6 +304,8 @@ class GeminiClient:
|
|||
"""Convert content from OAI format to Gemini format"""
|
||||
rst = []
|
||||
if isinstance(content, str):
|
||||
if content == "":
|
||||
content = "empty" # Empty content is not allowed.
|
||||
if self.use_vertexai:
|
||||
rst.append(VertexAIPart.from_text(content))
|
||||
else:
|
||||
|
@ -372,6 +408,35 @@ class GeminiClient:
|
|||
|
||||
return rst
|
||||
|
||||
@staticmethod
|
||||
def _to_vertexai_safety_settings(safety_settings):
|
||||
"""Convert safety settings to VertexAI format if needed,
|
||||
like when specifying them in the OAI_CONFIG_LIST
|
||||
"""
|
||||
if isinstance(safety_settings, list) and all(
|
||||
[
|
||||
isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
|
||||
for safety_setting in safety_settings
|
||||
]
|
||||
):
|
||||
vertexai_safety_settings = []
|
||||
for safety_setting in safety_settings:
|
||||
if safety_setting["category"] not in VertexAIHarmCategory.__members__:
|
||||
invalid_category = safety_setting["category"]
|
||||
logger.error(f"Safety setting category {invalid_category} is invalid")
|
||||
elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
|
||||
invalid_threshold = safety_setting["threshold"]
|
||||
logger.error(f"Safety threshold {invalid_threshold} is invalid")
|
||||
else:
|
||||
vertexai_safety_setting = VertexAISafetySetting(
|
||||
category=safety_setting["category"],
|
||||
threshold=safety_setting["threshold"],
|
||||
)
|
||||
vertexai_safety_settings.append(vertexai_safety_setting)
|
||||
return vertexai_safety_settings
|
||||
else:
|
||||
return safety_settings
|
||||
|
||||
|
||||
def _to_pil(data: str) -> Image.Image:
|
||||
"""
|
||||
|
|
|
@ -13,7 +13,15 @@ from openai import OpenAI
|
|||
from openai.types.beta.assistant import Assistant
|
||||
from packaging.version import parse
|
||||
|
||||
NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version", "azure_ad_token", "azure_ad_token_provider"]
|
||||
NON_CACHE_KEY = [
|
||||
"api_key",
|
||||
"base_url",
|
||||
"api_type",
|
||||
"api_version",
|
||||
"azure_ad_token",
|
||||
"azure_ad_token_provider",
|
||||
"credentials",
|
||||
]
|
||||
DEFAULT_AZURE_API_VERSION = "2024-02-01"
|
||||
OAI_PRICE1K = {
|
||||
# https://openai.com/api/pricing/
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing_extensions import Annotated
|
|||
|
||||
import autogen
|
||||
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, filter_config, initiate_chats
|
||||
from autogen.agentchat.chat import _post_process_carryover_item
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
from conftest import reason, skip_openai # noqa: E402
|
||||
|
@ -620,6 +621,15 @@ def test_udf_message_in_chats():
|
|||
print(chat_results[1].summary, chat_results[1].cost)
|
||||
|
||||
|
||||
def test_post_process_carryover_item():
|
||||
gemini_carryover_item = {"content": "How can I help you?", "role": "model"}
|
||||
assert (
|
||||
_post_process_carryover_item(gemini_carryover_item) == gemini_carryover_item["content"]
|
||||
), "Incorrect carryover postprocessing"
|
||||
carryover_item = "How can I help you?"
|
||||
assert _post_process_carryover_item(carryover_item) == carryover_item, "Incorrect carryover postprocessing"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chats()
|
||||
# test_chats_general()
|
||||
|
@ -628,3 +638,4 @@ if __name__ == "__main__":
|
|||
# test_chats_w_func()
|
||||
# test_chat_messages_for_summary()
|
||||
# test_udf_message_in_chats()
|
||||
test_post_process_carryover_item()
|
||||
|
|
|
@ -1463,6 +1463,58 @@ def test_adding_duplicate_function_warning():
|
|||
)
|
||||
|
||||
|
||||
def test_process_gemini_carryover():
|
||||
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
|
||||
content = "I am your assistant."
|
||||
carryover_content = "How can I help you?"
|
||||
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
|
||||
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=gemini_kwargs)
|
||||
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"
|
||||
|
||||
|
||||
def test_process_carryover():
|
||||
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
|
||||
content = "I am your assistant."
|
||||
carryover = "How can I help you?"
|
||||
kwargs = {"carryover": carryover}
|
||||
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
|
||||
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
|
||||
|
||||
carryover_l = ["How can I help you?"]
|
||||
kwargs = {"carryover": carryover_l}
|
||||
proc_content = dummy_agent_1._process_carryover(content=content, kwargs=kwargs)
|
||||
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
|
||||
|
||||
proc_content_empty_carryover = dummy_agent_1._process_carryover(content=content, kwargs={"carryover": None})
|
||||
assert proc_content_empty_carryover == content, "Incorrect carryover processing"
|
||||
|
||||
|
||||
def test_handle_gemini_carryover():
|
||||
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
|
||||
content = "I am your assistant"
|
||||
carryover_content = "How can I help you?"
|
||||
gemini_kwargs = {"carryover": [{"content": carryover_content}]}
|
||||
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=gemini_kwargs)
|
||||
assert proc_content == content + "\nContext: \n" + carryover_content, "Incorrect carryover processing"
|
||||
|
||||
|
||||
def test_handle_carryover():
|
||||
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
|
||||
content = "I am your assistant."
|
||||
carryover = "How can I help you?"
|
||||
kwargs = {"carryover": carryover}
|
||||
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
|
||||
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
|
||||
|
||||
carryover_l = ["How can I help you?"]
|
||||
kwargs = {"carryover": carryover_l}
|
||||
proc_content = dummy_agent_1._handle_carryover(message=content, kwargs=kwargs)
|
||||
assert proc_content == content + "\nContext: \n" + carryover, "Incorrect carryover processing"
|
||||
|
||||
proc_content_empty_carryover = dummy_agent_1._handle_carryover(message=content, kwargs={"carryover": None})
|
||||
assert proc_content_empty_carryover == content, "Incorrect carryover processing"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_trigger()
|
||||
# test_context()
|
||||
|
@ -1473,6 +1525,10 @@ if __name__ == "__main__":
|
|||
# test_max_turn()
|
||||
# test_process_before_send()
|
||||
# test_message_func()
|
||||
|
||||
test_summary()
|
||||
test_adding_duplicate_function_warning()
|
||||
# test_function_registration_e2e_sync()
|
||||
|
||||
test_process_gemini_carryover()
|
||||
test_process_carryover()
|
||||
|
|
|
@ -1,15 +1,26 @@
|
|||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import google.auth
|
||||
from google.api_core.exceptions import InternalServerError
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.aiplatform.initializer import global_config as vertexai_global_config
|
||||
from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
|
||||
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
|
||||
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
|
||||
|
||||
from autogen.oai.gemini import GeminiClient
|
||||
|
||||
skip = False
|
||||
except ImportError:
|
||||
GeminiClient = object
|
||||
VertexAIHarmBlockThreshold = object
|
||||
VertexAIHarmCategory = object
|
||||
VertexAISafetySetting = object
|
||||
vertexai_global_config = object
|
||||
InternalServerError = object
|
||||
skip = True
|
||||
|
||||
|
@ -30,7 +41,24 @@ def mock_response():
|
|||
|
||||
@pytest.fixture
|
||||
def gemini_client():
|
||||
return GeminiClient(api_key="fake_api_key")
|
||||
system_message = [
|
||||
"You are a helpful AI assistant.",
|
||||
]
|
||||
return GeminiClient(api_key="fake_api_key", system_message=system_message)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_google_auth_default_client():
|
||||
system_message = [
|
||||
"You are a helpful AI assistant.",
|
||||
]
|
||||
return GeminiClient(system_message=system_message)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_client_with_credentials():
|
||||
mock_credentials = MagicMock(Credentials)
|
||||
return GeminiClient(credentials=mock_credentials)
|
||||
|
||||
|
||||
# Test compute location initialization and configuration
|
||||
|
@ -42,9 +70,13 @@ def test_compute_location_initialization():
|
|||
) # Should raise an AssertionError due to specifying API key and compute location
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_google_auth_default_client():
|
||||
return GeminiClient()
|
||||
# Test project initialization and configuration
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_project_initialization():
|
||||
with pytest.raises(AssertionError):
|
||||
GeminiClient(
|
||||
api_key="fake_api_key", project_id="fake-project-id"
|
||||
) # Should raise an AssertionError due to specifying API key and compute location
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
|
@ -52,6 +84,23 @@ def test_valid_initialization(gemini_client):
|
|||
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_google_application_credentials_initialization():
|
||||
GeminiClient(google_application_credentials="credentials.json", project_id="fake-project-id")
|
||||
assert (
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] == "credentials.json"
|
||||
), "Incorrect Google Application Credentials initialization"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_vertexai_initialization():
|
||||
mock_credentials = MagicMock(Credentials)
|
||||
GeminiClient(credentials=mock_credentials, project_id="fake-project-id", location="us-west1")
|
||||
assert vertexai_global_config.location == "us-west1", "Incorrect VertexAI location initialization"
|
||||
assert vertexai_global_config.project == "fake-project-id", "Incorrect VertexAI project initialization"
|
||||
assert vertexai_global_config.credentials == mock_credentials, "Incorrect VertexAI credentials initialization"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_gemini_message_handling(gemini_client):
|
||||
messages = [
|
||||
|
@ -94,6 +143,113 @@ def test_gemini_message_handling(gemini_client):
|
|||
assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_gemini_empty_message_handling(gemini_client):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are my personal assistant."},
|
||||
{"role": "model", "content": "How can I help you?"},
|
||||
{"role": "user", "content": ""},
|
||||
{
|
||||
"role": "model",
|
||||
"content": "Please provide me with some context or a request! I need more information to assist you.",
|
||||
},
|
||||
{"role": "user", "content": ""},
|
||||
]
|
||||
|
||||
converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)
|
||||
assert converted_messages[-3].parts[0].text == "empty", "Empty message is not converted to 'empty' correctly"
|
||||
assert converted_messages[-1].parts[0].text == "empty", "Empty message is not converted to 'empty' correctly"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_vertexai_safety_setting_conversion(gemini_client):
|
||||
safety_settings = [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
|
||||
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
|
||||
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
|
||||
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
|
||||
]
|
||||
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
|
||||
harm_categories = [
|
||||
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
]
|
||||
expected_safety_settings = [
|
||||
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
|
||||
for category in harm_categories
|
||||
]
|
||||
|
||||
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
|
||||
for i, expected_setting in enumerate(expected_safety_settings):
|
||||
converted_setting = converted_safety_settings[i]
|
||||
yield expected_setting.to_dict() == converted_setting.to_dict()
|
||||
|
||||
assert len(converted_safety_settings) == len(
|
||||
expected_safety_settings
|
||||
), "The length of the safety settings is incorrect"
|
||||
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
|
||||
assert all(settings_comparison), "Converted safety settings are incorrect"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_vertexai_default_safety_settings_dict(gemini_client):
|
||||
safety_settings = {
|
||||
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
}
|
||||
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
|
||||
|
||||
expected_safety_settings = {
|
||||
category: VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH for category in safety_settings.keys()
|
||||
}
|
||||
|
||||
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
|
||||
for expected_setting_key in expected_safety_settings.keys():
|
||||
expected_setting = expected_safety_settings[expected_setting_key]
|
||||
converted_setting = converted_safety_settings[expected_setting_key]
|
||||
yield expected_setting == converted_setting
|
||||
|
||||
assert len(converted_safety_settings) == len(
|
||||
expected_safety_settings
|
||||
), "The length of the safety settings is incorrect"
|
||||
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
|
||||
assert all(settings_comparison), "Converted safety settings are incorrect"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
def test_vertexai_safety_setting_list(gemini_client):
|
||||
harm_categories = [
|
||||
VertexAIHarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
VertexAIHarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
]
|
||||
|
||||
expected_safety_settings = safety_settings = [
|
||||
VertexAISafetySetting(category=category, threshold=VertexAIHarmBlockThreshold.BLOCK_ONLY_HIGH)
|
||||
for category in harm_categories
|
||||
]
|
||||
|
||||
print(safety_settings)
|
||||
|
||||
converted_safety_settings = GeminiClient._to_vertexai_safety_settings(safety_settings)
|
||||
|
||||
def compare_safety_settings(converted_safety_settings, expected_safety_settings):
|
||||
for i, expected_setting in enumerate(expected_safety_settings):
|
||||
converted_setting = converted_safety_settings[i]
|
||||
yield expected_setting.to_dict() == converted_setting.to_dict()
|
||||
|
||||
assert len(converted_safety_settings) == len(
|
||||
expected_safety_settings
|
||||
), "The length of the safety settings is incorrect"
|
||||
settings_comparison = compare_safety_settings(converted_safety_settings, expected_safety_settings)
|
||||
assert all(settings_comparison), "Converted safety settings are incorrect"
|
||||
|
||||
|
||||
# Test error handling
|
||||
@patch("autogen.oai.gemini.genai")
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
|
@ -150,6 +306,62 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client):
|
|||
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.vertexai.init")
|
||||
def test_vertexai_create_response(mock_init, mock_generative_model, gemini_client_with_credentials):
|
||||
# Mock the genai model configuration and creation process
|
||||
mock_chat = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_init.return_value = None
|
||||
mock_generative_model.return_value = mock_model
|
||||
mock_model.start_chat.return_value = mock_chat
|
||||
|
||||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = "Example response"
|
||||
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
|
||||
|
||||
# Call the create method
|
||||
response = gemini_client_with_credentials.create(
|
||||
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
|
||||
)
|
||||
|
||||
# Assertions to check if response is structured as expected
|
||||
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.vertexai.init")
|
||||
def test_vertexai_default_auth_create_response(mock_init, mock_generative_model, gemini_google_auth_default_client):
|
||||
# Mock the genai model configuration and creation process
|
||||
mock_chat = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_init.return_value = None
|
||||
mock_generative_model.return_value = mock_model
|
||||
mock_model.start_chat.return_value = mock_chat
|
||||
|
||||
# Set up a mock for the chat history item access and the text attribute return
|
||||
mock_history_part = MagicMock()
|
||||
mock_history_part.text = "Example response"
|
||||
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
|
||||
|
||||
# Setup the mock to return a mocked chat response
|
||||
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
|
||||
|
||||
# Call the create method
|
||||
response = gemini_google_auth_default_client.create(
|
||||
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
|
||||
)
|
||||
|
||||
# Assertions to check if response is structured as expected
|
||||
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.genai.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.genai.configure")
|
||||
|
@ -195,3 +407,49 @@ def test_create_vision_model_response(mock_configure, mock_generative_model, gem
|
|||
assert (
|
||||
response.choices[0].message.content == "Vision model output"
|
||||
), "Response content should match expected output from vision model"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
|
||||
@patch("autogen.oai.gemini.GenerativeModel")
|
||||
@patch("autogen.oai.gemini.vertexai.init")
|
||||
def test_vertexai_create_vision_model_response(mock_init, mock_generative_model, gemini_google_auth_default_client):
|
||||
# Mock the genai model configuration and creation process
|
||||
mock_model = MagicMock()
|
||||
mock_init.return_value = None
|
||||
mock_generative_model.return_value = mock_model
|
||||
|
||||
# Set up a mock to simulate the vision model behavior
|
||||
mock_vision_response = MagicMock()
|
||||
mock_vision_part = MagicMock(text="Vision model output")
|
||||
|
||||
# Setting up the chain of return values for vision model response
|
||||
mock_vision_response.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = mock_vision_part
|
||||
|
||||
mock_model.generate_content.return_value = mock_vision_response
|
||||
|
||||
# Call the create method with vision model parameters
|
||||
response = gemini_google_auth_default_client.create(
|
||||
{
|
||||
"model": "gemini-pro-vision", # Vision model name
|
||||
"messages": [
|
||||
{
|
||||
"content": [
|
||||
{"type": "text", "text": "Let's play a game."},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
|
||||
},
|
||||
},
|
||||
],
|
||||
"role": "user",
|
||||
}
|
||||
], # Assuming a simple content input for vision
|
||||
"stream": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Assertions to check if response is structured as expected
|
||||
assert (
|
||||
response.choices[0].message.content == "Vision model output"
|
||||
), "Response content should match expected output from vision model"
|
||||
|
|
|
@ -64,7 +64,7 @@
|
|||
" },\n",
|
||||
" {\n",
|
||||
" \"model\": \"gemini-1.5-pro\",\n",
|
||||
" \"project\": \"your-awesome-google-cloud-project-id\",\n",
|
||||
" \"project_id\": \"your-awesome-google-cloud-project-id\",\n",
|
||||
" \"location\": \"us-west1\",\n",
|
||||
" \"google_application_credentials\": \"your-google-service-account-key.json\"\n",
|
||||
" },\n",
|
||||
|
|
|
@ -14,9 +14,14 @@
|
|||
"\n",
|
||||
"## Requirements\n",
|
||||
"\n",
|
||||
"AutoGen requires `Python>=3.8`. To run this notebook example, please install with the [gemini] option:\n",
|
||||
"Install AutoGen with Gemini features:\n",
|
||||
"```bash\n",
|
||||
"pip install \"pyautogen[gemini]\"\n",
|
||||
"pip install pyautogen[gemini]\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Install other Dependencies of this Notebook\n",
|
||||
"```bash\n",
|
||||
"pip install chromadb markdownify pypdf\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Google Cloud Account\n",
|
||||
|
@ -66,41 +71,6 @@
|
|||
" * Please consider restricting the permissions on the key file. For example, you could run `chmod 600 autogen-with-gemini-service-account-key.json` if your keyfile is called autogen-with-gemini-service-account-key.json."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-02-13T23:40:52.317406Z",
|
||||
"iopub.status.busy": "2023-02-13T23:40:52.316561Z",
|
||||
"iopub.status.idle": "2023-02-13T23:40:52.321193Z",
|
||||
"shell.execute_reply": "2023-02-13T23:40:52.320628Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# %pip install \"pyautogen[gemini]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2023-02-13T23:40:54.634335Z",
|
||||
"iopub.status.busy": "2023-02-13T23:40:54.633929Z",
|
||||
"iopub.status.idle": "2023-02-13T23:40:56.105700Z",
|
||||
"shell.execute_reply": "2023-02-13T23:40:56.105085Z"
|
||||
},
|
||||
"slideshow": {
|
||||
"slide_type": "slide"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import autogen"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
|
@ -109,7 +79,9 @@
|
|||
"### Configure Authentication\n",
|
||||
"\n",
|
||||
"Authentication happens using standard [Google Cloud authentication methods](https://cloud.google.com/docs/authentication), <br/> which means\n",
|
||||
"that either an already active session can be reused, or by specifying the Google application credentials of a service account.\n",
|
||||
"that either an already active session can be reused, or by specifying the Google application credentials of a service account. <br/><br/>\n",
|
||||
"Additionally, AutoGen also supports authentication using `Credentials` objects in Python with the [google-auth library](https://google-auth.readthedocs.io/), which enables even more flexibility.<br/>\n",
|
||||
"For example, we can even use impersonated credentials.\n",
|
||||
"\n",
|
||||
"#### Use Service Account Keyfile\n",
|
||||
"\n",
|
||||
|
@ -121,7 +93,13 @@
|
|||
"\n",
|
||||
"If you are using [Cloud Shell](https://shell.cloud.google.com/cloudshell) or [Cloud Shell editor](https://shell.cloud.google.com/cloudshell/editor) in Google Cloud, <br/> then you are already authenticated. If you have the Google Cloud SDK installed locally, <br/> then you can login by running `gcloud auth login` in the command line. \n",
|
||||
"\n",
|
||||
"Detailed instructions for installing the Google Cloud SDK can be found [here](https://cloud.google.com/sdk/docs/install)."
|
||||
"Detailed instructions for installing the Google Cloud SDK can be found [here](https://cloud.google.com/sdk/docs/install).\n",
|
||||
"\n",
|
||||
"#### Authentication with the Google Auth Library for Python\n",
|
||||
"\n",
|
||||
"The google-auth library supports a wide range of authentication scenarios, and you can simply pass a previously created `Credentials` object to the `llm_config`.<br/>\n",
|
||||
"The [official documentation](https://google-auth.readthedocs.io/) of the Python package provides a detailed overview of the supported methods and usage examples.<br/>\n",
|
||||
"If you are already authenticated, like in [Cloud Shell](https://shell.cloud.google.com/cloudshell), or after running the `gcloud auth login` command in a CLI, then the `google.auth.default()` Python method will automatically return your currently active credentials."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -147,7 +125,7 @@
|
|||
" {\n",
|
||||
" \"model\": \"gemini-1.5-pro\",\n",
|
||||
" \"api_type\": \"google\",\n",
|
||||
" \"project\": \"autogen-with-gemini\",\n",
|
||||
" \"project_id\": \"autogen-with-gemini\",\n",
|
||||
" \"location\": \"us-west1\",\n",
|
||||
" \"google_application_credentials\": \"autogen-with-gemini-service-account-key.json\"\n",
|
||||
" },\n",
|
||||
|
@ -172,17 +150,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from vertexai.generative_models import (\n",
|
||||
" GenerationConfig,\n",
|
||||
" GenerativeModel,\n",
|
||||
" HarmBlockThreshold,\n",
|
||||
" HarmCategory,\n",
|
||||
" Part,\n",
|
||||
")\n",
|
||||
"from vertexai.generative_models import HarmBlockThreshold, HarmCategory\n",
|
||||
"\n",
|
||||
"safety_settings = {\n",
|
||||
" HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,\n",
|
||||
|
@ -194,7 +166,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -205,6 +177,7 @@
|
|||
"from PIL import Image\n",
|
||||
"from termcolor import colored\n",
|
||||
"\n",
|
||||
"import autogen\n",
|
||||
"from autogen import Agent, AssistantAgent, ConversableAgent, UserProxyAgent\n",
|
||||
"from autogen.agentchat.contrib.img_utils import _to_pil, get_image_data\n",
|
||||
"from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent\n",
|
||||
|
@ -215,7 +188,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -242,7 +215,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -252,33 +225,28 @@
|
|||
"\u001b[33muser_proxy\u001b[0m (to assistant):\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script, \n",
|
||||
" which returns the value of the definite integral.\n",
|
||||
" Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script,\n",
|
||||
" which returns the value of the definite integral\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"Plan:\n",
|
||||
"1. (Code) Use Python's numerical integration library to compute the integral.\n",
|
||||
"2. (Language) Output the result.\n",
|
||||
"1. (code) Use Python's `scipy.integrate.quad` function to compute the integral. \n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"# filename: integral.py\n",
|
||||
"import scipy.integrate\n",
|
||||
"from scipy.integrate import quad\n",
|
||||
"\n",
|
||||
"f = lambda x: x**2\n",
|
||||
"result, error = scipy.integrate.quad(f, 0, 1)\n",
|
||||
"def f(x):\n",
|
||||
" return x**2\n",
|
||||
"\n",
|
||||
"result, error = quad(f, 0, 1)\n",
|
||||
"\n",
|
||||
"print(f\"The definite integral of x^2 from 0 to 1 is: {result}\")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Let me know when you have executed the code. \n",
|
||||
"Let me know when you have executed this code. \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n",
|
||||
|
@ -294,13 +262,11 @@
|
|||
"--------------------------------------------------------------------------------\n",
|
||||
"\u001b[33massistant\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
"The code executed successfully and returned the value of the definite integral as approximately 0.33333333333333337. \n",
|
||||
"The script executed successfully and returned the definite integral's value as approximately 0.33333333333333337. \n",
|
||||
"\n",
|
||||
"This aligns with the analytical solution:\n",
|
||||
"This aligns with the analytical solution. The indefinite integral of x^2 is (x^3)/3. Evaluating this from 0 to 1 gives us (1^3)/3 - (0^3)/3 = 1/3 = 0.33333...\n",
|
||||
"\n",
|
||||
"The integral of x^2 is (x^3)/3. Evaluating this from 0 to 1 gives us (1^3)/3 - (0^3)/3 = 1/3 = 0.33333...\n",
|
||||
"\n",
|
||||
"Therefore, the answer is verified to be correct.\n",
|
||||
"Therefore, the script successfully computed the integral of x^2 from 0 to 1.\n",
|
||||
"\n",
|
||||
"TERMINATE\n",
|
||||
"\n",
|
||||
|
@ -325,7 +291,7 @@
|
|||
" assistant,\n",
|
||||
" message=\"\"\"\n",
|
||||
" Compute the integral of the function f(x)=x^2 on the interval 0 to 1 using a Python script,\n",
|
||||
" which returns the value of the definite integral.\"\"\",\n",
|
||||
" which returns the value of the definite integral\"\"\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
@ -334,12 +300,52 @@
|
|||
"metadata": {},
|
||||
"source": [
|
||||
"## Example with Gemini Multimodal\n",
|
||||
"Authentication is the same for vision models as for the text based Gemini models"
|
||||
"Authentication is the same for vision models as for the text based Gemini models. <br/>\n",
|
||||
"In this example an object of type `Credentials` will be supplied in order to authenticate.<br/>\n",
|
||||
"Here, we will use the google application default credentials, so make sure to run the following commands if you are not yet authenticated:\n",
|
||||
"```bash\n",
|
||||
"export GOOGLE_APPLICATION_CREDENTIALS=autogen-with-gemini-service-account-key.json\n",
|
||||
"gcloud auth application-default login\n",
|
||||
"gcloud config set project autogen-with-gemini\n",
|
||||
"```\n",
|
||||
"The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is a path to our service account JSON keyfile, as described in the [Use Service Account Keyfile](#Use Service Account Keyfile) section above.<br/>\n",
|
||||
"We also need to set the Google cloud project, which is `autogen-with-gemini` in this example.<br/><br/>\n",
|
||||
"\n",
|
||||
"Note, we could also run `gcloud auth login` in case we wish to use our personal Google account instead of a service account.\n",
|
||||
"In this case we need to run the following commands:\n",
|
||||
"```bash\n",
|
||||
"gcloud auth login\n",
|
||||
"gcloud config set project autogen-with-gemini\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import google.auth\n",
|
||||
"\n",
|
||||
"scopes = [\"https://www.googleapis.com/auth/cloud-platform\"]\n",
|
||||
"\n",
|
||||
"credentials, project_id = google.auth.default(scopes)\n",
|
||||
"\n",
|
||||
"gemini_vision_config = [\n",
|
||||
" {\n",
|
||||
" \"model\": \"gemini-pro-vision\",\n",
|
||||
" \"api_type\": \"google\",\n",
|
||||
" \"project_id\": project_id,\n",
|
||||
" \"credentials\": credentials,\n",
|
||||
" \"location\": \"us-west1\",\n",
|
||||
" \"safety_settings\": safety_settings,\n",
|
||||
" }\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -356,7 +362,7 @@
|
|||
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
|
||||
"\u001b[33mGemini Vision\u001b[0m (to user_proxy):\n",
|
||||
"\n",
|
||||
" The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.\n",
|
||||
" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\n",
|
||||
"\n",
|
||||
"--------------------------------------------------------------------------------\n"
|
||||
]
|
||||
|
@ -364,17 +370,17 @@
|
|||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ChatResult(chat_id=None, chat_history=[{'content': 'Describe what is in this image?\\n<img https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png?raw=true>.', 'role': 'assistant'}, {'content': ' The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.', 'role': 'user'}], summary=' The image shows a taxonomy of different types of conversational agents. The taxonomy is based on two dimensions: agent customization and flexible conversation patterns. Agent customization refers to the ability of the agent to be tailored to the individual user. Flexible conversation patterns refer to the ability of the agent to engage in different types of conversations, such as joint chat and hierarchical chat.', cost={'usage_including_cached_inference': {'total_cost': 0.0002385, 'gemini-pro-vision': {'cost': 0.0002385, 'prompt_tokens': 267, 'completion_tokens': 70, 'total_tokens': 337}}, 'usage_excluding_cached_inference': {'total_cost': 0.0002385, 'gemini-pro-vision': {'cost': 0.0002385, 'prompt_tokens': 267, 'completion_tokens': 70, 'total_tokens': 337}}}, human_input=[])"
|
||||
"ChatResult(chat_id=None, chat_history=[{'content': 'Describe what is in this image?\\n<img https://github.com/microsoft/autogen/blob/main/website/static/img/autogen_agentchat.png?raw=true>.', 'role': 'assistant'}, {'content': \" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\", 'role': 'user'}], summary=\" The image describes a conversational agent that is able to have a conversation with a human user. The agent can be customized to the user's preferences. The conversation can be in form of a joint chat or hierarchical chat.\", cost={'usage_including_cached_inference': {'total_cost': 0.0001995, 'gemini-pro-vision': {'cost': 0.0001995, 'prompt_tokens': 267, 'completion_tokens': 44, 'total_tokens': 311}}, 'usage_excluding_cached_inference': {'total_cost': 0}}, human_input=[])"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"image_agent = MultimodalConversableAgent(\n",
|
||||
" \"Gemini Vision\", llm_config={\"config_list\": config_list_gemini_vision, \"seed\": seed}, max_consecutive_auto_reply=1\n",
|
||||
" \"Gemini Vision\", llm_config={\"config_list\": gemini_vision_config, \"seed\": seed}, max_consecutive_auto_reply=1\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"user_proxy = UserProxyAgent(\"user_proxy\", human_input_mode=\"NEVER\", max_consecutive_auto_reply=0)\n",
|
||||
|
@ -415,7 +421,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.14"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
|
Loading…
Reference in New Issue