Merge "Gemini" feature into the main branch (#2360)

* Start Gemini integration: works OK with text now

* Gemini notebook lint

* Wrap the Gemini import in try/except

* Debug: id issue for chat completion in Gemini

* Add RAG example

* Update docs for RAG

* Fix missing pydash

* Remove temp folder

* Fix test error in runs/7206014032/job/19630042864

* Fix tqdm warning

* Fix notebook output

* Gemini's vision model is supported now

* Install instructions for the Gemini branch

* Catch and retry on Internal Server Error 500

* Allow gemini to take more flexible messages
i.e., it can take messages where "user" is not the last role.

* Use int time for Gemini client

* Handle other exceptions in gemini call

* rename to "create" function for gemini

* GeminiClient compatible with ModelClient now

* Lint

* Update instructions in Gemini notebook

* Lint

* Remove empty blocks from Gemini notebook

* Add gemini into example page

* self.create instead of call

* Add py and Py into python execution

* Remove error code from merging

* Remove pydash dependency for gemini

* Add cloud-gemini doc

* Remove temp file

* cache import update

* Add test case for summary with multimodal input

* Lint: warnings instead of print

* Add test cases for gemini

* Gemini test config

* Disable default model for gemini

* Typo fix in gemini workflow

* Correct grammar in example notebook

* Raise if "model" is not provided in create(...)

* Move TODOs into a roadmap

* Update .github/workflows/contrib-tests.yml

Co-authored-by: Davor Runje <davor@airt.ai>

* Gemini test config update

* Update setup.py

Co-authored-by: Davor Runje <davor@airt.ai>

* Update test/oai/test_gemini.py

Co-authored-by: Davor Runje <davor@airt.ai>

* Update test/oai/test_gemini.py

Co-authored-by: Davor Runje <davor@airt.ai>

* Remove Python 3.8 from the Gemini test matrix
google-generativeai is not available for Windows with Python 3.8

* Update import error handling for gemini

* Count tokens and cost for gemini

---------

Co-authored-by: Li Jiang <bnujli@gmail.com>
Co-authored-by: Davor Runje <davor@airt.ai>
Authored by Beibin Li, committed via GitHub on 2024-04-16 17:24:07 -07:00 (commit 0aaf30a8da, parent f4977e2263)
10 changed files with 2064 additions and 2 deletions

.github/workflows/contrib-tests.yml

@@ -256,6 +256,44 @@ jobs:
file: ./coverage.xml
flags: unittests
GeminiTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for Gemini
run: |
pip install -e .[gemini,test]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
coverage run -a -m pytest test/oai/test_gemini.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
ContextHandling:
runs-on: ${{ matrix.os }}
strategy:

.gitignore (vendored): 4 lines added

@@ -172,6 +172,10 @@ test/my_tmp/*
# Storage for the AgentEval output
test/test_files/agenteval-in-out/out/
# local cache or coding folder
local_cache/
coding/
# Files created by tests
*tmp_code_*
test/agentchat/test_agent_scripts/*

.pre-commit-config.yaml

@@ -43,6 +43,7 @@ repos:
website/static/img/ag.svg |
website/yarn.lock |
website/docs/tutorial/code-executors.ipynb |
website/docs/topics/non-openai-models/cloud-gemini.ipynb |
notebook/.*
)$
# See https://jaredkhan.com/blog/mypy-pre-commit

autogen/agentchat/conversable_agent.py

@@ -1121,7 +1121,15 @@ class ConversableAgent(LLMAgent):
def _last_msg_as_summary(sender, recipient, summary_args) -> str:
"""Get a chat summary from the last message of the recipient."""
try:
summary = recipient.last_message(sender)["content"].replace("TERMINATE", "")
content = recipient.last_message(sender)["content"]
if isinstance(content, str):
summary = content.replace("TERMINATE", "")
elif isinstance(content, list):
# Remove the `TERMINATE` word in the content list.
summary = [
{**x, "text": x["text"].replace("TERMINATE", "")} if isinstance(x, dict) and "text" in x else x
for x in content
]
except (IndexError, AttributeError) as e:
warnings.warn(f"Cannot extract summary using last_msg: {e}. Using an empty str as summary.", UserWarning)
summary = ""

autogen/oai/client.py

@@ -42,6 +42,13 @@ else:
TOOL_ENABLED = True
ERROR = None
try:
from autogen.oai.gemini import GeminiClient
gemini_import_exception: Optional[ImportError] = None
except ImportError as e:
gemini_import_exception = e
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -425,6 +432,10 @@ class OpenAIWrapper:
self._configure_azure_openai(config, openai_config)
client = AzureOpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
elif api_type is not None and api_type.startswith("google"):
if gemini_import_exception:
raise ImportError("Please install `google-generativeai` to use Google OpenAI API.")
self._clients.append(GeminiClient(**openai_config))
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
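A minimal sketch of how this branch is exercised through OpenAIWrapper, assuming the gemini extra is installed and GOOGLE_API_KEY is set; the prompt is illustrative only:

import os

from autogen import OpenAIWrapper

# An api_type starting with "google" makes the wrapper append a GeminiClient
# instead of an OpenAIClient, per the elif branch above.
wrapper = OpenAIWrapper(
    config_list=[{"api_type": "google", "model": "gemini-pro", "api_key": os.environ.get("GOOGLE_API_KEY")}]
)
response = wrapper.create(messages=[{"role": "user", "content": "Say hello in one short sentence."}])
print(wrapper.extract_text_or_completion_object(response))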

autogen/oai/gemini.py (new file): 310 lines added

@@ -0,0 +1,310 @@
"""Create a OpenAI-compatible client for Gemini features.
Example:
llm_config={
"config_list": [{
"api_type": "google",
"model": "models/gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY")
}
]}
agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
Resources:
- https://ai.google.dev/docs
- https://cloud.google.com/vertex-ai/docs/generative-ai/migrate-from-azure
- https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
- https://ai.google.dev/api/python/google/generativeai/ChatSession
"""
from __future__ import annotations
import base64
import os
import random
import re
import time
import warnings
from io import BytesIO
from typing import Any, Dict, List, Mapping, Union
import google.generativeai as genai
import requests
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from PIL import Image
class GeminiClient:
"""Client for Google's Gemini API.
Please visit this [page](https://github.com/microsoft/autogen/issues/2387) for the roadmap of the Gemini integration in AutoGen.
"""
def __init__(self, **kwargs):
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
self.api_key = os.getenv("GOOGLE_API_KEY")
assert (
self.api_key
), "Please provide api_key in your config list entry for Gemini or set the GOOGLE_API_KEY env variable."
def message_retrieval(self, response) -> List:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
"""
return [choice.message for choice in response.choices]
def cost(self, response) -> float:
return response.cost
@staticmethod
def get_usage(response) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
# ... # pragma: no cover
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
"cost": response.cost,
"model": response.model,
}
def create(self, params: Dict) -> ChatCompletion:
model_name = params.get("model", "gemini-pro")
if not model_name:
raise ValueError(
"Please provide a model name for the Gemini Client. "
"You can configurate it in the OAI Config List file. "
"See this [LLM configuration tutorial](https://microsoft.github.io/autogen/docs/topics/llm_configuration/) for more details."
)
params.get("api_type", "google") # not used
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
params.get("temperature", 0.5)
params.get("top_p", 1.0)
params.get("max_tokens", 4096)
if stream:
# warn user that streaming is not supported
warnings.warn(
"Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
UserWarning,
)
if n_response > 1:
warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)
if "vision" not in model_name:
# A. create and call the chat model.
gemini_messages = oai_messages_to_gemini_messages(messages)
# we use chat model by default
model = genai.GenerativeModel(model_name)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
f"InternalServerError `500` occurs when calling Gemini's chat model. Retry in {delay} seconds...",
UserWarning,
)
time.sleep(delay)
except Exception as e:
raise RuntimeError(f"Google GenAI exception occurred while calling Gemini API: {e}")
else:
# `ans = response.text` is unstable. Use the following code instead.
ans: str = chat.history[-1].parts[0].text
break
if ans is None:
raise RuntimeError(f"Fail to get response from Google AI after retrying {attempt + 1} times.")
prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
elif model_name == "gemini-pro-vision":
# B. handle the vision model
# Gemini's vision model does not support chat history yet
model = genai.GenerativeModel(model_name)
genai.configure(api_key=self.api_key)
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
user_message = oai_content_to_gemini_content(messages[-1]["content"])
if len(messages) > 2:
warnings.warn(
"Gemini's vision model does not support chat history yet. "
"We only use the last message as the prompt.",
UserWarning,
)
response = model.generate_content(user_message, stream=stream)
# ans = response.text
ans: str = response._result.candidates[0].content.parts[0].text
prompt_tokens = model.count_tokens(user_message).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
# C. convert output
message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
choices = [Choice(finish_reason="stop", index=0, message=message)]
response_oai = ChatCompletion(
id=str(random.randint(0, 1000)),
model=model_name,
created=int(time.time() * 1000),
object="chat.completion",
choices=choices,
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
)
return response_oai
def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
if "1.5" in model_name or "gemini-experimental" in model_name:
# "gemini-1.5-pro-preview-0409"
# Cost is $7 per million input tokens and $21 per million output tokens
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
# Cost is $0.5 per million input tokens and $1.5 per million output tokens
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6
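# Worked example of the pricing above: 10,000 prompt tokens and 2,000 completion tokens on a
# 1.5 / experimental model cost 7.0 * 10_000 / 1e6 + 21.0 * 2_000 / 1e6 = 0.07 + 0.042 = 0.112 USD,
# while the same usage on gemini-1.0-pro costs 0.5 * 10_000 / 1e6 + 1.5 * 2_000 / 1e6 = 0.008 USD.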
def oai_content_to_gemini_content(content: Union[str, List]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
if isinstance(content, str):
rst.append(Part(text=content))
return rst
assert isinstance(content, list)
for msg in content:
if isinstance(msg, dict):
assert "type" in msg, f"Missing 'type' field in message: {msg}"
if msg["type"] == "text":
rst.append(Part(text=msg["text"]))
elif msg["type"] == "image_url":
b64_img = get_image_data(msg["image_url"]["url"])
img = _to_pil(b64_img)
rst.append(img)
else:
raise ValueError(f"Unsupported message type: {msg['type']}")
else:
raise ValueError(f"Unsupported message type: {type(msg)}")
return rst
def concat_parts(parts: List[Part]) -> List:
"""Concatenate parts with the same type.
If two adjacent parts both have the "text" attribute, they will be joined into one part.
"""
if not parts:
return []
concatenated_parts = []
previous_part = parts[0]
for current_part in parts[1:]:
if previous_part.text != "":
previous_part.text += current_part.text
else:
concatenated_parts.append(previous_part)
previous_part = current_part
if previous_part.text == "":
previous_part.text = "empty" # Empty content is not allowed.
concatenated_parts.append(previous_part)
return concatenated_parts
def oai_messages_to_gemini_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Gemini format.
Make sure the "user" role and "model" role are interleaved.
Also, make sure the last item is from the "user" role.
"""
prev_role = None
rst = []
curr_parts = []
for i, message in enumerate(messages):
parts = oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"
if prev_role is None or role == prev_role:
curr_parts += parts
elif role != prev_role:
rst.append(Content(parts=concat_parts(curr_parts), role=prev_role))
curr_parts = parts
prev_role = role
# handle the last message
rst.append(Content(parts=concat_parts(curr_parts), role=role))
# Gemini is strict about the order of roles, such that
# 1. The messages should alternate between the "user" and "model" roles.
# 2. The last message must be from the "user" role.
# We add a dummy "continue" message if the last role is not "user".
if rst[-1].role != "user":
rst.append(Content(parts=oai_content_to_gemini_content("continue"), role="user"))
return rst
def _to_pil(data: str) -> Image.Image:
"""
Converts a base64 encoded image data string to a PIL Image object.
This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
and finally creates and returns a PIL Image object from the BytesIO object.
Parameters:
data (str): The base64 encoded image data string.
Returns:
Image.Image: The PIL Image object created from the input data.
"""
return Image.open(BytesIO(base64.b64decode(data)))
def get_image_data(image_file: str, use_b64=True) -> bytes:
if image_file.startswith("http://") or image_file.startswith("https://"):
response = requests.get(image_file)
content = response.content
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
else:
image = Image.open(image_file).convert("RGB")
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
if use_b64:
return base64.b64encode(content).decode("utf-8")
else:
return content
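As a rough illustration of the message conversion above (a sketch, assuming the gemini extra is installed so that autogen.oai.gemini imports cleanly):

from autogen.oai.gemini import oai_messages_to_gemini_messages

# Consecutive "user"-side messages (system counts as user) are merged into one Content,
# and a dummy "continue" user message is appended because the last role maps to "model".
oai_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi there."},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
gemini_messages = oai_messages_to_gemini_messages(oai_messages)
print([m.role for m in gemini_messages])  # expected: ['user', 'model', 'user']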

autogen/token_count_utils.py

@@ -66,7 +66,7 @@ def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613"):
elif isinstance(input, list) or isinstance(input, dict):
return _num_token_from_messages(input, model=model)
else:
raise ValueError("input must be str, list or dict")
raise ValueError(f"input must be str, list or dict, but we got {type(input)}")
def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
@@ -111,6 +111,9 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
elif "gpt-4" in model:
logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
elif "gemini" in model:
logger.info("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
return _num_token_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""

setup.py

@@ -64,6 +64,7 @@ setuptools.setup(
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
"graph": ["networkx", "matplotlib"],
"gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
"redis": ["redis"],
"cosmosdb": ["azure-cosmos>=4.2.0"],

test/oai/test_gemini.py (new file): 148 lines added

@@ -0,0 +1,148 @@
from unittest.mock import MagicMock, patch
import pytest
try:
from google.api_core.exceptions import InternalServerError
from autogen.oai.gemini import GeminiClient
skip = False
except ImportError:
GeminiClient = object
InternalServerError = object
skip = True
# Fixtures for mock data
@pytest.fixture
def mock_response():
class MockResponse:
def __init__(self, text, choices, usage, cost, model):
self.text = text
self.choices = choices
self.usage = usage
self.cost = cost
self.model = model
return MockResponse
@pytest.fixture
def gemini_client():
return GeminiClient(api_key="fake_api_key")
# Test initialization and configuration
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_initialization():
with pytest.raises(AssertionError):
GeminiClient() # Should raise an AssertionError due to missing API key
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"
# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_internal_server_error_retry(mock_genai, gemini_client):
mock_genai.GenerativeModel.side_effect = [InternalServerError("Test Error"), None] # First call fails
# Mock successful response
mock_chat = MagicMock()
mock_chat.send_message.return_value = "Successful response"
mock_genai.GenerativeModel.return_value.start_chat.return_value = mock_chat
with patch.object(gemini_client, "create", return_value="Retried Successfully"):
response = gemini_client.create({"model": "gemini-pro", "messages": [{"content": "Hello"}]})
assert response == "Retried Successfully", "Should retry on InternalServerError"
# Test cost calculation
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_cost_calculation(gemini_client, mock_response):
response = mock_response(
text="Example response",
choices=[{"message": "Test message 1"}],
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
cost=0.01,
model="gemini-pro",
)
assert gemini_client.cost(response) > 0, "Cost should be greater than zero"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
def test_create_response(mock_configure, mock_generative_model, gemini_client):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_configure.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat
# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
# Call the create method
response = gemini_client.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
# Assertions to check if response is structured as expected
assert response.choices[0].message.content == "Example response", "Response content should match expected output"
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")
def test_create_vision_model_response(mock_configure, mock_generative_model, gemini_client):
# Mock the genai model configuration and creation process
mock_model = MagicMock()
mock_configure.return_value = None
mock_generative_model.return_value = mock_model
# Set up a mock to simulate the vision model behavior
mock_vision_response = MagicMock()
mock_vision_part = MagicMock(text="Vision model output")
# Setting up the chain of return values for vision model response
mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = (
mock_vision_part
)
mock_model.generate_content.return_value = mock_vision_response
# Call the create method with vision model parameters
response = gemini_client.create(
{
"model": "gemini-pro-vision", # Vision model name
"messages": [
{
"content": [
{"type": "text", "text": "Let's play a game."},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
},
},
],
"role": "user",
}
], # Assuming a simple content input for vision
"stream": False,
}
)
# Assertions to check if response is structured as expected
assert (
response.choices[0].message.content == "Vision model output"
), "Response content should match expected output from vision model"

File diff suppressed because one or more lines are too long