Large Multimodal Models in AgentChat (#554)

* LMM Code added

* LLaVA notebook update

* Test cases and Notebook modified for OpenAI v1

* Move LMM into contrib
To resolve test issues and deploy issues
In the future, we can install pillow by default, and then move back
LMM agents into agentchat

* LMM test setup update

* try...except... clause for LMM tests

* disable patch for llava agent test
To resolve dependencies issue for build

* Add LMM Blog

* Change docstring for LMM agents

* Docstring update patch

* llava: insert reply at position 1 now
So, it can still handle human_input_mode
and max_consecutive_reply

* Resolve comments
Fixing: typos, blogs, yml, and add OpenAIWrapper

* Signature typo fix for LMM agent: system_message

* Update LMM "content" from latest OpenAI release
Reference  https://platform.openai.com/docs/guides/vision

* update LMM test according to latest OpenAI release

* Fully support GPT-4V now
1. Add a notebook for GPT-4V. LLava notebook also updated.
2. img_utils updated
3. GPT-4V formatter now return base64 image with mime type
4. Infer mime type directly from b64 image content (while loading
   without suffix)
5. Test cases modified according to all the related changes.

* GPT-4V link updated in blog

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Beibin Li 2023-11-06 13:33:51 -08:00 committed by GitHub
parent 306ac4d7f7
commit b41b366549
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 2116 additions and 732 deletions

60
.github/workflows/contrib-lmm.yml vendored Normal file
View File

@ -0,0 +1,60 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: ContribTests
on:
pull_request:
branches: ['main', 'dev/v0.2']
paths:
- 'autogen/img_utils.py'
- 'autogen/agentchat/contrib/multimodal_conversable_agent.py'
- 'autogen/agentchat/contrib/llava_agent.py'
- 'test/test_img_utils.py'
- 'test/agentchat/contrib/test_lmm.py'
- 'test/agentchat/contrib/test_llava.py'
- '.github/workflows/lmm-test.yml'
- 'setup.py'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
LMMTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for LMM
run: |
pip install -e .[lmm]
pip uninstall -y openai
- name: Test LMM and LLaVA
run: |
pytest test/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests

View File

@ -1,8 +1,8 @@
from .agent import Agent
from .conversable_agent import ConversableAgent
from .assistant_agent import AssistantAgent
from .user_proxy_agent import UserProxyAgent
from .conversable_agent import ConversableAgent
from .groupchat import GroupChat, GroupChatManager
from .user_proxy_agent import UserProxyAgent
__all__ = [
"Agent",

View File

@ -0,0 +1,178 @@
import json
import logging
import os
import pdb
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import replicate
import requests
from regex import R
from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.code_utils import content_str
from autogen.img_utils import get_image_data, llava_formater
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
logger = logging.getLogger(__name__)
# we will override the following variables later.
SEP = "###"
DEFAULT_LLAVA_SYS_MSG = "You are an AI agent and you can view images."
class LLaVAAgent(MultimodalConversableAgent):
def __init__(
self,
name: str,
system_message: Optional[Tuple[str, List]] = DEFAULT_LLAVA_SYS_MSG,
*args,
**kwargs,
):
"""
Args:
name (str): agent name.
system_message (str): system message for the ChatCompletion inference.
Please override this attribute if you want to reprogram the agent.
**kwargs (dict): Please refer to other kwargs in
[ConversableAgent](../conversable_agent#__init__).
"""
super().__init__(
name,
system_message=system_message,
*args,
**kwargs,
)
assert self.llm_config is not None, "llm_config must be provided."
self.register_reply([Agent, None], reply_func=LLaVAAgent._image_reply, position=1)
def _image_reply(self, messages=None, sender=None, config=None):
# Note: we did not use "llm_config" yet.
if all((messages is None, sender is None)):
error_msg = f"Either {messages=} or {sender=} must be provided."
logger.error(error_msg)
raise AssertionError(error_msg)
if messages is None:
messages = self._oai_messages[sender]
# The formats for LLaVA and GPT are different. So, we manually handle them here.
images = []
prompt = content_str(self.system_message) + "\n"
for msg in messages:
role = "Human" if msg["role"] == "user" else "Assistant"
# pdb.set_trace()
images += [d["image_url"]["url"] for d in msg["content"] if d["type"] == "image_url"]
content_prompt = content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]
print(colored(prompt, "blue"))
out = ""
retry = 10
while len(out) == 0 and retry > 0:
# image names will be inferred automatically from llava_call
out = llava_call_binary(
prompt=prompt,
images=images,
config_list=self.llm_config["config_list"],
temperature=self.llm_config.get("temperature", 0.5),
max_new_tokens=self.llm_config.get("max_new_tokens", 2000),
)
retry -= 1
assert out != "", "Empty response from LLaVA."
return True, out
def _llava_call_binary_with_config(
prompt: str, images: list, config: dict, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
):
if config["base_url"].find("0.0.0.0") >= 0 or config["base_url"].find("localhost") >= 0:
llava_mode = "local"
else:
llava_mode = "remote"
if llava_mode == "local":
headers = {"User-Agent": "LLaVA Client"}
pload = {
"model": config["model"],
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"stop": SEP,
"images": images,
}
response = requests.post(
config["base_url"].rstrip("/") + "/worker_generate_stream", headers=headers, json=pload, stream=False
)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"].split(SEP)[-1]
elif llava_mode == "remote":
# The Replicate version of the model only support 1 image for now.
img = "data:image/jpeg;base64," + images[0]
response = replicate.run(
config["base_url"], input={"image": img, "prompt": prompt.replace("<image>", " "), "seed": seed}
)
# The yorickvp/llava-13b model can stream output as it's running.
# The predict method returns an iterator, and you can iterate over that output.
output = ""
for item in response:
# https://replicate.com/yorickvp/llava-13b/versions/2facb4a474a0462c15041b78b1ad70952ea46b5ec6ad29583c0b29dbd4249591/api#output-schema
output += item
# Remove the prompt and the space.
output = output.replace(prompt, "").strip().rstrip()
return output
def llava_call_binary(
prompt: str, images: list, config_list: list, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
):
# TODO 1: add caching around the LLaVA call to save compute and cost
# TODO 2: add `seed` to ensure reproducibility. The seed is not working now.
for config in config_list:
try:
return _llava_call_binary_with_config(prompt, images, config, max_new_tokens, temperature, seed)
except Exception as e:
print(f"Error: {e}")
continue
def llava_call(prompt: str, llm_config: dict) -> str:
"""
Makes a call to the LLaVA service to generate text based on a given prompt
"""
prompt, images = llava_formater(prompt, order_image_tokens=False)
for im in images:
if len(im) == 0:
raise RuntimeError("An image is empty!")
return llava_call_binary(
prompt,
images,
config_list=llm_config["config_list"],
max_new_tokens=llm_config.get("max_new_tokens", 2000),
temperature=llm_config.get("temperature", 0.5),
seed=llm_config.get("seed", None),
)

View File

@ -0,0 +1,107 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from autogen import OpenAIWrapper
from autogen.agentchat import Agent, ConversableAgent
from autogen.img_utils import gpt4v_formatter
try:
from termcolor import colored
except ImportError:
def colored(x, *args, **kwargs):
return x
from autogen.code_utils import content_str
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
class MultimodalConversableAgent(ConversableAgent):
def __init__(
self,
name: str,
system_message: Optional[Union[str, List]] = DEFAULT_LMM_SYS_MSG,
is_termination_msg: str = None,
*args,
**kwargs,
):
"""
Args:
name (str): agent name.
system_message (str): system message for the OpenAIWrapper inference.
Please override this attribute if you want to reprogram the agent.
**kwargs (dict): Please refer to other kwargs in
[ConversableAgent](../conversable_agent#__init__).
"""
super().__init__(
name,
system_message,
is_termination_msg=is_termination_msg,
*args,
**kwargs,
)
self.update_system_message(system_message)
self._is_termination_msg = (
is_termination_msg
if is_termination_msg is not None
else (lambda x: any([item["text"] == "TERMINATE" for item in x.get("content") if item["type"] == "text"]))
)
@property
def system_message(self) -> List:
"""Return the system message."""
return self._oai_system_message[0]["content"]
def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message.
Args:
system_message (str): system message for the OpenAIWrapper inference.
"""
self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
self._oai_system_message[0]["role"] = "system"
@staticmethod
def _message_to_dict(message: Union[Dict, List, str]):
"""Convert a message to a dictionary.
The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary.
"""
if isinstance(message, str):
return {"content": gpt4v_formatter(message)}
if isinstance(message, list):
return {"content": message}
else:
return message
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
# print the message received
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
if message.get("role") == "function":
func_print = f"***** Response from calling function \"{message['name']}\" *****"
print(colored(func_print, "green"), flush=True)
print(content_str(message["content"]), flush=True)
print(colored("*" * len(func_print), "green"), flush=True)
else:
content = message.get("content")
if content is not None:
if "context" in message:
content = OpenAIWrapper.instantiate(
content,
message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
print(content_str(content), flush=True)
if "function_call" in message:
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
print(colored(func_print, "green"), flush=True)
print(
"Arguments: \n",
message["function_call"].get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print(colored("*" * len(func_print), "green"), flush=True)
print("\n", "-" * 80, flush=True, sep="")

View File

@ -1,14 +1,15 @@
import subprocess
import sys
import logging
import os
import pathlib
from typing import List, Dict, Tuple, Optional, Union, Callable
import re
import subprocess
import sys
import time
from hashlib import md5
import logging
from autogen import oai
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from hashlib import md5
from typing import Callable, Dict, List, Optional, Tuple, Union
from autogen import oai
try:
import docker
@ -29,6 +30,19 @@ PATH_SEPARATOR = WIN32 and "\\" or "/"
logger = logging.getLogger(__name__)
def content_str(content: Union[str, List]) -> str:
if type(content) is str:
return content
rst = ""
for item in content:
if item["type"] == "text":
rst += item["text"]
else:
assert isinstance(item, dict) and item["type"] == "image_url", "Wrong content format."
rst += "<image>"
return rst
def infer_lang(code):
"""infer the language for the code.
TODO: make it robust.
@ -46,12 +60,13 @@ def infer_lang(code):
def extract_code(
text: str, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
text: Union[str, List], pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
) -> List[Tuple[str, str]]:
"""Extract code from a text.
Args:
text (str): The text to extract code from.
text (str or List): The content to extract code from. The content can be
a string or a list, as returned by standard GPT or multimodal GPT.
pattern (str, optional): The regular expression pattern for finding the
code block. Defaults to CODE_BLOCK_PATTERN.
detect_single_line_code (bool, optional): Enable the new feature for
@ -62,6 +77,7 @@ def extract_code(
If there is no code block in the input text, the language would be "unknown".
If there is code block but the language is not specified, the language would be "".
"""
text = content_str(text)
if not detect_single_line_code:
match = re.findall(pattern, text, flags=re.DOTALL)
return match if match else [(UNKNOWN, text)]

170
autogen/img_utils.py Normal file
View File

@ -0,0 +1,170 @@
import base64
import mimetypes
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
import requests
from PIL import Image
def get_image_data(image_file: str, use_b64=True) -> bytes:
if image_file.startswith("http://") or image_file.startswith("https://"):
response = requests.get(image_file)
content = response.content
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
else:
image = Image.open(image_file).convert("RGB")
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
if use_b64:
return base64.b64encode(content).decode("utf-8")
else:
return content
def llava_formater(prompt: str, order_image_tokens: bool = False) -> Tuple[str, List[str]]:
"""
Formats the input prompt by replacing image tags and returns the new prompt along with image locations.
Parameters:
- prompt (str): The input string that may contain image tags like <img ...>.
- order_image_tokens (bool, optional): Whether to order the image tokens with numbers.
It will be useful for GPT-4V. Defaults to False.
Returns:
- Tuple[str, List[str]]: A tuple containing the formatted string and a list of images (loaded in b64 format).
"""
# Initialize variables
new_prompt = prompt
image_locations = []
images = []
image_count = 0
# Regular expression pattern for matching <img ...> tags
img_tag_pattern = re.compile(r"<img ([^>]+)>")
# Find all image tags
for match in img_tag_pattern.finditer(prompt):
image_location = match.group(1)
try:
img_data = get_image_data(image_location)
except Exception as e:
# Remove the token
print(f"Warning! Unable to load image from {image_location}, because of {e}")
new_prompt = new_prompt.replace(match.group(0), "", 1)
continue
image_locations.append(image_location)
images.append(img_data)
# Increment the image count and replace the tag in the prompt
new_token = f"<image {image_count}>" if order_image_tokens else "<image>"
new_prompt = new_prompt.replace(match.group(0), new_token, 1)
image_count += 1
return new_prompt, images
def convert_base64_to_data_uri(base64_image):
def _get_mime_type_from_data_uri(base64_image):
# Decode the base64 string
image_data = base64.b64decode(base64_image)
# Check the first few bytes for known signatures
if image_data.startswith(b"\xff\xd8\xff"):
return "image/jpeg"
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
return "image/png"
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
return "image/gif"
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
return "image/webp"
return "image/jpeg" # use jpeg for unknown formats, best guess.
mime_type = _get_mime_type_from_data_uri(base64_image)
data_uri = f"data:{mime_type};base64,{base64_image}"
return data_uri
def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
"""
Formats the input prompt by replacing image tags and returns a list of text and images.
Parameters:
- prompt (str): The input string that may contain image tags like <img ...>.
Returns:
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
"""
output = []
last_index = 0
image_count = 0
# Regular expression pattern for matching <img ...> tags
img_tag_pattern = re.compile(r"<img ([^>]+)>")
# Find all image tags
for match in img_tag_pattern.finditer(prompt):
image_location = match.group(1)
try:
img_data = get_image_data(image_location)
except Exception as e:
# Warning and skip this token
print(f"Warning! Unable to load image from {image_location}, because {e}")
continue
# Add text before this image tag to output list
output.append({"type": "text", "text": prompt[last_index : match.start()]})
# Add image data to output list
output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
last_index = match.end()
image_count += 1
# Add remaining text to output list
output.append({"type": "text", "text": prompt[last_index:]})
return output
def extract_img_paths(paragraph: str) -> list:
"""
Extract image paths (URLs or local paths) from a text paragraph.
Parameters:
paragraph (str): The input text paragraph.
Returns:
list: A list of extracted image paths.
"""
# Regular expression to match image URLs and file paths
img_path_pattern = re.compile(
r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp)|\S+\.(?:jpg|jpeg|png|gif|bmp))\b", re.IGNORECASE
)
# Find all matches in the paragraph
img_paths = re.findall(img_path_pattern, paragraph)
return img_paths
def _to_pil(data: str) -> Image.Image:
"""
Converts a base64 encoded image data string to a PIL Image object.
This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
and finally creates and returns a PIL Image object from the BytesIO object.
Parameters:
data (str): The base64 encoded image data string.
Returns:
Image.Image: The PIL Image object created from the input data.
"""
return Image.open(BytesIO(base64.b64decode(data)))

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,12 +1,12 @@
import setuptools
import os
import setuptools
here = os.path.abspath(os.path.dirname(__file__))
with open("README.md", "r", encoding="UTF-8") as fh:
long_description = fh.read()
# Get the code version
version = {}
with open(os.path.join(here, "autogen/version.py")) as fp:
@ -22,7 +22,6 @@ install_requires = [
"tiktoken",
]
setuptools.setup(
name="pyautogen",
version=__version__,
@ -52,6 +51,7 @@ setuptools.setup(
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
"retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython"],
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],
},
classifiers=[
"Programming Language :: Python :: 3",

View File

@ -0,0 +1,129 @@
import unittest
from unittest.mock import MagicMock, patch
import pytest
import autogen
try:
from autogen.agentchat.contrib.llava_agent import (
LLaVAAgent,
_llava_call_binary_with_config,
llava_call,
llava_call_binary,
)
except ImportError:
skip = True
else:
skip = False
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLLaVAAgent(unittest.TestCase):
def setUp(self):
self.agent = LLaVAAgent(
name="TestAgent",
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": "Fake"}],
},
)
def test_init(self):
self.assertIsInstance(self.agent, LLaVAAgent)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLLavaCallBinaryWithConfig(unittest.TestCase):
@patch("requests.post")
def test_local_mode(self, mock_post):
# Mocking the response of requests.post
mock_response = MagicMock()
mock_response.iter_lines.return_value = [b'{"text":"response text"}']
mock_post.return_value = mock_response
# Calling the function
output = _llava_call_binary_with_config(
prompt="Test Prompt",
images=[],
config={"base_url": "http://0.0.0.0/api", "model": "test-model"},
max_new_tokens=1000,
temperature=0.5,
seed=1,
)
# Verifying the results
self.assertEqual(output, "response text")
mock_post.assert_called_once_with(
"http://0.0.0.0/api/worker_generate_stream",
headers={"User-Agent": "LLaVA Client"},
json={
"model": "test-model",
"prompt": "Test Prompt",
"max_new_tokens": 1000,
"temperature": 0.5,
"stop": "###",
"images": [],
},
stream=False,
)
@patch("replicate.run")
def test_remote_mode(self, mock_run):
# Mocking the response of replicate.run
mock_run.return_value = iter(["response ", "text"])
# Calling the function
output = _llava_call_binary_with_config(
prompt="Test Prompt",
images=["image_data"],
config={"base_url": "http://remote/api", "model": "test-model"},
max_new_tokens=1000,
temperature=0.5,
seed=1,
)
# Verifying the results
self.assertEqual(output, "response text")
mock_run.assert_called_once_with(
"http://remote/api",
input={"image": "_data", "prompt": "Test Prompt", "seed": 1},
)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLLavaCall(unittest.TestCase):
@patch("autogen.agentchat.contrib.llava_agent.llava_formater")
@patch("autogen.agentchat.contrib.llava_agent.llava_call_binary")
def test_llava_call(self, mock_llava_call_binary, mock_llava_formater):
# Set up the mocks
mock_llava_formater.return_value = ("formatted prompt", ["image1", "image2"])
mock_llava_call_binary.return_value = "Generated Text"
# Set up the llm_config dictionary
llm_config = {
"config_list": [{"api_key": "value", "base_url": "localhost:8000"}],
"max_new_tokens": 2000,
"temperature": 0.5,
"seed": 1,
}
# Call the function
result = llava_call("Test Prompt", llm_config)
# Check the results
mock_llava_formater.assert_called_once_with("Test Prompt", order_image_tokens=False)
mock_llava_call_binary.assert_called_once_with(
"formatted prompt",
["image1", "image2"],
config_list=llm_config["config_list"],
max_new_tokens=2000,
temperature=0.5,
seed=1,
)
self.assertEqual(result, "Generated Text")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,83 @@
import unittest
from unittest.mock import MagicMock
import pytest
import autogen
from autogen.agentchat.agent import Agent
try:
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
except ImportError:
skip = True
else:
skip = False
base64_encoded_image = (
""
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestMultimodalConversableAgent(unittest.TestCase):
def setUp(self):
self.agent = MultimodalConversableAgent(
name="TestAgent",
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-fake"}],
},
)
def test_system_message(self):
# Test default system message
self.assertEqual(
self.agent.system_message,
[
{
"type": "text",
"text": "You are a helpful AI assistant.",
}
],
)
# Test updating system message
new_message = f"We will discuss <img {base64_encoded_image}> in this conversation."
self.agent.update_system_message(new_message)
self.assertEqual(
self.agent.system_message,
[
{"type": "text", "text": "We will discuss "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "text", "text": " in this conversation."},
],
)
def test_message_to_dict(self):
# Test string message
message_str = "Hello"
expected_dict = {"content": [{"type": "text", "text": "Hello"}]}
self.assertDictEqual(self.agent._message_to_dict(message_str), expected_dict)
# Test list message
message_list = [{"type": "text", "text": "Hello"}]
expected_dict = {"content": message_list}
self.assertDictEqual(self.agent._message_to_dict(message_list), expected_dict)
# Test dictionary message
message_dict = {"content": [{"type": "text", "text": "Hello"}]}
self.assertDictEqual(self.agent._message_to_dict(message_dict), message_dict)
def test_print_received_message(self):
sender = Agent(name="SenderAgent")
message_str = "Hello"
self.agent._print_received_message = MagicMock() # Mocking print method to avoid actual print
self.agent._print_received_message(message_str, sender)
self.agent._print_received_message.assert_called_with(message_str, sender)
if __name__ == "__main__":
unittest.main()

View File

@ -1,16 +1,20 @@
import sys
import os
import sys
import unittest
import pytest
import autogen
from autogen.code_utils import (
PATH_SEPARATOR,
UNKNOWN,
extract_code,
WIN32,
content_str,
execute_code,
infer_lang,
extract_code,
improve_code,
improve_function,
PATH_SEPARATOR,
WIN32,
infer_lang,
)
KEY_LOC = "notebook"
@ -315,8 +319,36 @@ def _test_improve():
f.write(improvement)
class TestContentStr(unittest.TestCase):
def test_string_content(self):
self.assertEqual(content_str("simple string"), "simple string")
def test_list_of_text_content(self):
content = [{"type": "text", "text": "hello"}, {"type": "text", "text": " world"}]
self.assertEqual(content_str(content), "hello world")
def test_mixed_content(self):
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "url": "http://example.com/image.png"}]
self.assertEqual(content_str(content), "hello<image>")
def test_invalid_content(self):
content = [{"type": "text", "text": "hello"}, {"type": "wrong_type", "url": "http://example.com/image.png"}]
with self.assertRaises(AssertionError) as context:
content_str(content)
self.assertIn("Wrong content format", str(context.exception))
def test_empty_list(self):
self.assertEqual(content_str([]), "")
def test_non_dict_in_list(self):
content = ["string", {"type": "text", "text": "text"}]
with self.assertRaises(TypeError):
content_str(content)
if __name__ == "__main__":
# test_infer_lang()
# test_extract_code()
test_execute_code()
# test_find_code()
unittest.main()

193
test/test_img_utils.py Normal file
View File

@ -0,0 +1,193 @@
import base64
import os
import pdb
import unittest
from unittest.mock import patch
import pytest
import requests
try:
from PIL import Image
from autogen.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formater
except ImportError:
skip = True
else:
skip = False
base64_encoded_image = (
""
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)
raw_encoded_image = (
"iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGetImageData(unittest.TestCase):
def test_http_image(self):
with patch("requests.get") as mock_get:
mock_response = requests.Response()
mock_response.status_code = 200
mock_response._content = b"fake image content"
mock_get.return_value = mock_response
result = get_image_data("http://example.com/image.png")
self.assertEqual(result, base64.b64encode(b"fake image content").decode("utf-8"))
def test_base64_encoded_image(self):
result = get_image_data(base64_encoded_image)
self.assertEqual(result, base64_encoded_image.split(",", 1)[1])
def test_local_image(self):
# Create a temporary file to simulate a local image file.
temp_file = "_temp.png"
image = Image.new("RGB", (60, 30), color=(73, 109, 137))
image.save(temp_file)
result = get_image_data(temp_file)
with open(temp_file, "rb") as temp_image_file:
temp_image_file.seek(0)
expected_content = base64.b64encode(temp_image_file.read()).decode("utf-8")
self.assertEqual(result, expected_content)
os.remove(temp_file)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestLlavaFormater(unittest.TestCase):
def test_no_images(self):
"""
Test the llava_formater function with a prompt containing no images.
"""
prompt = "This is a test."
expected_output = (prompt, [])
result = llava_formater(prompt)
self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data")
def test_with_images(self, mock_get_image_data):
"""
Test the llava_formater function with a prompt containing images.
"""
# Mock the get_image_data function to return a fixed string.
mock_get_image_data.return_value = raw_encoded_image
prompt = "This is a test with an image <img http://example.com/image.png>."
expected_output = ("This is a test with an image <image>.", [raw_encoded_image])
result = llava_formater(prompt)
self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data")
def test_with_ordered_images(self, mock_get_image_data):
"""
Test the llava_formater function with ordered image tokens.
"""
# Mock the get_image_data function to return a fixed string.
mock_get_image_data.return_value = raw_encoded_image
prompt = "This is a test with an image <img http://example.com/image.png>."
expected_output = ("This is a test with an image <image 0>.", [raw_encoded_image])
result = llava_formater(prompt, order_image_tokens=True)
self.assertEqual(result, expected_output)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGpt4vFormatter(unittest.TestCase):
def test_no_images(self):
"""
Test the gpt4v_formatter function with a prompt containing no images.
"""
prompt = "This is a test."
expected_output = [{"type": "text", "text": prompt}]
result = gpt4v_formatter(prompt)
self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data")
def test_with_images(self, mock_get_image_data):
"""
Test the gpt4v_formatter function with a prompt containing images.
"""
# Mock the get_image_data function to return a fixed string.
mock_get_image_data.return_value = raw_encoded_image
prompt = "This is a test with an image <img http://example.com/image.png>."
expected_output = [
{"type": "text", "text": "This is a test with an image "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "text", "text": "."},
]
result = gpt4v_formatter(prompt)
self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data")
def test_multiple_images(self, mock_get_image_data):
"""
Test the gpt4v_formatter function with a prompt containing multiple images.
"""
# Mock the get_image_data function to return a fixed string.
mock_get_image_data.return_value = raw_encoded_image
prompt = (
"This is a test with images <img http://example.com/image1.png> and <img http://example.com/image2.png>."
)
expected_output = [
{"type": "text", "text": "This is a test with images "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "text", "text": " and "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "text", "text": "."},
]
result = gpt4v_formatter(prompt)
self.assertEqual(result, expected_output)
@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestExtractImgPaths(unittest.TestCase):
def test_no_images(self):
"""
Test the extract_img_paths function with a paragraph containing no images.
"""
paragraph = "This is a test paragraph with no images."
expected_output = []
result = extract_img_paths(paragraph)
self.assertEqual(result, expected_output)
def test_with_images(self):
"""
Test the extract_img_paths function with a paragraph containing images.
"""
paragraph = (
"This is a test paragraph with images http://example.com/image1.jpg and http://example.com/image2.png."
)
expected_output = ["http://example.com/image1.jpg", "http://example.com/image2.png"]
result = extract_img_paths(paragraph)
self.assertEqual(result, expected_output)
def test_mixed_case(self):
"""
Test the extract_img_paths function with mixed case image extensions.
"""
paragraph = "Mixed case extensions http://example.com/image.JPG and http://example.com/image.Png."
expected_output = ["http://example.com/image.JPG", "http://example.com/image.Png"]
result = extract_img_paths(paragraph)
self.assertEqual(result, expected_output)
def test_local_paths(self):
"""
Test the extract_img_paths function with local file paths.
"""
paragraph = "Local paths image1.jpeg and image2.GIF."
expected_output = ["image1.jpeg", "image2.GIF"]
result = extract_img_paths(paragraph)
self.assertEqual(result, expected_output)
if __name__ == "__main__":
unittest.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

View File

@ -0,0 +1,77 @@
---
title: Multimodal with GPT-4V and LLaVA
authors: beibinli
tags: [LMM, multimodal]
---
![LMM Teaser](img/teaser.png)
**In Brief:**
* Introducing the **Multimodal Conversable Agent** and the **LLaVA Agent** to enhance LMM functionalities.
* Users can input text and images simultaneously using the `<img img_path>` tag to specify image loading.
* Demonstrated through the [GPT-4V notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_gpt-4v.ipynb).
* Demonstrated through the [LLaVA notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb).
## Introduction
Large multimodal models (LMMs) augment large language models (LLMs) with the ability to process multi-sensory data.
This blog post and the latest AutoGen update concentrate on visual comprehension. Users can input images, pose questions about them, and receive text-based responses from these LMMs.
We support the `gpt-4-vision-preview` model from OpenAI and `LLaVA` model from Microsoft now.
Here, we emphasize the **Multimodal Conversable Agent** and the **LLaVA Agent** due to their growing popularity.
GPT-4V represents the forefront in image comprehension, while LLaVA is an efficient model, fine-tuned from LLama-2.
## Installation
Incorporate the `lmm` feature during AutoGen installation:
```bash
pip install "pyautogen[lmm]"
```
Subsequently, import the **Multimodal Conversable Agent** or **LLaVA Agent** from AutoGen:
```python
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent # for GPT-4V
from autogen.agentchat.contrib.llava_agent import LLaVAAgent # for LLaVA
```
## Usage
A simple syntax has been defined to incorporate both messages and images within a single string.
Example of an in-context learning prompt:
```python
prompt = """You are now an image classifier for facial expressions. Here are
some examples.
<img happy.jpg> depicts a happy expression.
<img http://some_location.com/sad.jpg> represents a sad expression.
<img obama.jpg> portrays a neutral expression.
Now, identify the facial expression of this individual: <img unknown.png>
"""
agent = MultimodalConversableAgent()
user = UserProxyAgent()
user.initiate_chat(agent, message=prompt)
```
The `MultimodalConversableAgent` interprets the input prompt, extracting images from local or internet sources.
## Advanced Usage
Similar to other AutoGen agents, multimodal agents support multi-round dialogues with other agents, code generation, factual queries, and management via a GroupChat interface.
For example, the `FigureCreator` in our [GPT-4V notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_gpt-4v.ipynb) and [LLaVA notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb) integrates two agents: a coder (an AssistantAgent) and critics (a multimodal agent).
The coder drafts Python code for visualizations, while the critics provide insights for enhancement. Collaboratively, these agents aim to refine visual outputs.
With `human_input_mode=ALWAYS`, you can also contribute suggestions for better visualizations.
## Reference
- [GPT-4V System Card](https://openai.com/research/gpt-4v-system-card)
- [LLaVA GitHub](https://github.com/haotian-liu/LLaVA)
## Future Enhancements
For further inquiries or suggestions, please open an issue in the [AutoGen repository](https://github.com/microsoft/autogen/) or contact me directly at beibin.li@microsoft.com.
AutoGen will continue to evolve, incorporating more multimodal functionalities such as DALLE model integration, audio interaction, and video comprehension. Stay tuned for these exciting developments.

View File

@ -33,3 +33,9 @@ rickyloynd-microsoft:
title: Senior Research Engineer at Microsoft
url: https://github.com/rickyloynd-microsoft
image_url: https://github.com/rickyloynd-microsoft.png
beibinli:
name: Beibin Li
title: Senior Research Engineer at Microsoft
url: https://github.com/beibinli
image_url: https://github.com/beibinli.png

View File

@ -115,6 +115,18 @@ Example notebooks:
[Automated Code Generation and Question Answering with Qdrant based Retrieval Augmented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_qdrant_RetrieveChat.ipynb)
- #### Large Multimodal Model (LMM) Agents
We offered Multimodal Conversable Agent and LLaVA Agent. Please install with the [lmm] option to use it.
```bash
pip install "pyautogen[lmm]"
```
Example notebooks:
[LLaVA Agent](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb)
- #### mathchat
`pyautogen<0.2` offers an experimental agent for math problem solving. Please install with the [mathchat] option to use it.