mirror of https://github.com/microsoft/autogen.git
Large Multimodal Models in AgentChat (#554)
* LMM Code added * LLaVA notebook update * Test cases and Notebook modified for OpenAI v1 * Move LMM into contrib To resolve test issues and deploy issues In the future, we can install pillow by default, and then move back LMM agents into agentchat * LMM test setup update * try...except... clause for LMM tests * disable patch for llava agent test To resolve dependencies issue for build * Add LMM Blog * Change docstring for LMM agents * Docstring update patch * llava: insert reply at position 1 now So, it can still handle human_input_mode and max_consecutive_reply * Resolve comments Fixing: typos, blogs, yml, and add OpenAIWrapper * Signature typo fix for LMM agent: system_message * Update LMM "content" from latest OpenAI release Reference https://platform.openai.com/docs/guides/vision * update LMM test according to latest OpenAI release * Fully support GPT-4V now 1. Add a notebook for GPT-4V. LLava notebook also updated. 2. img_utils updated 3. GPT-4V formatter now return base64 image with mime type 4. Infer mime type directly from b64 image content (while loading without suffix) 5. Test cases modified according to all the related changes. * GPT-4V link updated in blog --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
306ac4d7f7
commit
b41b366549
|
@ -0,0 +1,60 @@
|
|||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: ContribTests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['main', 'dev/v0.2']
|
||||
paths:
|
||||
- 'autogen/img_utils.py'
|
||||
- 'autogen/agentchat/contrib/multimodal_conversable_agent.py'
|
||||
- 'autogen/agentchat/contrib/llava_agent.py'
|
||||
- 'test/test_img_utils.py'
|
||||
- 'test/agentchat/contrib/test_lmm.py'
|
||||
- 'test/agentchat/contrib/test_llava.py'
|
||||
- '.github/workflows/lmm-test.yml'
|
||||
- 'setup.py'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
LMMTest:
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install packages and dependencies for all tests
|
||||
run: |
|
||||
python -m pip install --upgrade pip wheel
|
||||
pip install pytest
|
||||
- name: Install packages and dependencies for LMM
|
||||
run: |
|
||||
pip install -e .[lmm]
|
||||
pip uninstall -y openai
|
||||
- name: Test LMM and LLaVA
|
||||
run: |
|
||||
pytest test/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
|
||||
- name: Coverage
|
||||
if: matrix.python-version == '3.10'
|
||||
run: |
|
||||
pip install coverage>=5.3
|
||||
coverage run -a -m pytest test/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
|
||||
coverage xml
|
||||
- name: Upload coverage to Codecov
|
||||
if: matrix.python-version == '3.10'
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
|
@ -1,8 +1,8 @@
|
|||
from .agent import Agent
|
||||
from .conversable_agent import ConversableAgent
|
||||
from .assistant_agent import AssistantAgent
|
||||
from .user_proxy_agent import UserProxyAgent
|
||||
from .conversable_agent import ConversableAgent
|
||||
from .groupchat import GroupChat, GroupChatManager
|
||||
from .user_proxy_agent import UserProxyAgent
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import pdb
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import replicate
|
||||
import requests
|
||||
from regex import R
|
||||
|
||||
from autogen.agentchat.agent import Agent
|
||||
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
|
||||
from autogen.code_utils import content_str
|
||||
from autogen.img_utils import get_image_data, llava_formater
|
||||
|
||||
try:
|
||||
from termcolor import colored
|
||||
except ImportError:
|
||||
|
||||
def colored(x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# we will override the following variables later.
|
||||
SEP = "###"
|
||||
|
||||
DEFAULT_LLAVA_SYS_MSG = "You are an AI agent and you can view images."
|
||||
|
||||
|
||||
class LLaVAAgent(MultimodalConversableAgent):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
system_message: Optional[Tuple[str, List]] = DEFAULT_LLAVA_SYS_MSG,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
name (str): agent name.
|
||||
system_message (str): system message for the ChatCompletion inference.
|
||||
Please override this attribute if you want to reprogram the agent.
|
||||
**kwargs (dict): Please refer to other kwargs in
|
||||
[ConversableAgent](../conversable_agent#__init__).
|
||||
"""
|
||||
super().__init__(
|
||||
name,
|
||||
system_message=system_message,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
assert self.llm_config is not None, "llm_config must be provided."
|
||||
self.register_reply([Agent, None], reply_func=LLaVAAgent._image_reply, position=1)
|
||||
|
||||
def _image_reply(self, messages=None, sender=None, config=None):
|
||||
# Note: we did not use "llm_config" yet.
|
||||
|
||||
if all((messages is None, sender is None)):
|
||||
error_msg = f"Either {messages=} or {sender=} must be provided."
|
||||
logger.error(error_msg)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
if messages is None:
|
||||
messages = self._oai_messages[sender]
|
||||
|
||||
# The formats for LLaVA and GPT are different. So, we manually handle them here.
|
||||
images = []
|
||||
prompt = content_str(self.system_message) + "\n"
|
||||
for msg in messages:
|
||||
role = "Human" if msg["role"] == "user" else "Assistant"
|
||||
# pdb.set_trace()
|
||||
images += [d["image_url"]["url"] for d in msg["content"] if d["type"] == "image_url"]
|
||||
content_prompt = content_str(msg["content"])
|
||||
prompt += f"{SEP}{role}: {content_prompt}\n"
|
||||
prompt += "\n" + SEP + "Assistant: "
|
||||
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]
|
||||
print(colored(prompt, "blue"))
|
||||
|
||||
out = ""
|
||||
retry = 10
|
||||
while len(out) == 0 and retry > 0:
|
||||
# image names will be inferred automatically from llava_call
|
||||
out = llava_call_binary(
|
||||
prompt=prompt,
|
||||
images=images,
|
||||
config_list=self.llm_config["config_list"],
|
||||
temperature=self.llm_config.get("temperature", 0.5),
|
||||
max_new_tokens=self.llm_config.get("max_new_tokens", 2000),
|
||||
)
|
||||
retry -= 1
|
||||
|
||||
assert out != "", "Empty response from LLaVA."
|
||||
|
||||
return True, out
|
||||
|
||||
|
||||
def _llava_call_binary_with_config(
|
||||
prompt: str, images: list, config: dict, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
|
||||
):
|
||||
if config["base_url"].find("0.0.0.0") >= 0 or config["base_url"].find("localhost") >= 0:
|
||||
llava_mode = "local"
|
||||
else:
|
||||
llava_mode = "remote"
|
||||
|
||||
if llava_mode == "local":
|
||||
headers = {"User-Agent": "LLaVA Client"}
|
||||
pload = {
|
||||
"model": config["model"],
|
||||
"prompt": prompt,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"temperature": temperature,
|
||||
"stop": SEP,
|
||||
"images": images,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
config["base_url"].rstrip("/") + "/worker_generate_stream", headers=headers, json=pload, stream=False
|
||||
)
|
||||
|
||||
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode("utf-8"))
|
||||
output = data["text"].split(SEP)[-1]
|
||||
elif llava_mode == "remote":
|
||||
# The Replicate version of the model only support 1 image for now.
|
||||
img = "data:image/jpeg;base64," + images[0]
|
||||
response = replicate.run(
|
||||
config["base_url"], input={"image": img, "prompt": prompt.replace("<image>", " "), "seed": seed}
|
||||
)
|
||||
# The yorickvp/llava-13b model can stream output as it's running.
|
||||
# The predict method returns an iterator, and you can iterate over that output.
|
||||
output = ""
|
||||
for item in response:
|
||||
# https://replicate.com/yorickvp/llava-13b/versions/2facb4a474a0462c15041b78b1ad70952ea46b5ec6ad29583c0b29dbd4249591/api#output-schema
|
||||
output += item
|
||||
|
||||
# Remove the prompt and the space.
|
||||
output = output.replace(prompt, "").strip().rstrip()
|
||||
return output
|
||||
|
||||
|
||||
def llava_call_binary(
|
||||
prompt: str, images: list, config_list: list, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
|
||||
):
|
||||
# TODO 1: add caching around the LLaVA call to save compute and cost
|
||||
# TODO 2: add `seed` to ensure reproducibility. The seed is not working now.
|
||||
|
||||
for config in config_list:
|
||||
try:
|
||||
return _llava_call_binary_with_config(prompt, images, config, max_new_tokens, temperature, seed)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
def llava_call(prompt: str, llm_config: dict) -> str:
|
||||
"""
|
||||
Makes a call to the LLaVA service to generate text based on a given prompt
|
||||
"""
|
||||
|
||||
prompt, images = llava_formater(prompt, order_image_tokens=False)
|
||||
|
||||
for im in images:
|
||||
if len(im) == 0:
|
||||
raise RuntimeError("An image is empty!")
|
||||
|
||||
return llava_call_binary(
|
||||
prompt,
|
||||
images,
|
||||
config_list=llm_config["config_list"],
|
||||
max_new_tokens=llm_config.get("max_new_tokens", 2000),
|
||||
temperature=llm_config.get("temperature", 0.5),
|
||||
seed=llm_config.get("seed", None),
|
||||
)
|
|
@ -0,0 +1,107 @@
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from autogen import OpenAIWrapper
|
||||
from autogen.agentchat import Agent, ConversableAgent
|
||||
from autogen.img_utils import gpt4v_formatter
|
||||
|
||||
try:
|
||||
from termcolor import colored
|
||||
except ImportError:
|
||||
|
||||
def colored(x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
|
||||
from autogen.code_utils import content_str
|
||||
|
||||
DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant."""
|
||||
|
||||
|
||||
class MultimodalConversableAgent(ConversableAgent):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
system_message: Optional[Union[str, List]] = DEFAULT_LMM_SYS_MSG,
|
||||
is_termination_msg: str = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
name (str): agent name.
|
||||
system_message (str): system message for the OpenAIWrapper inference.
|
||||
Please override this attribute if you want to reprogram the agent.
|
||||
**kwargs (dict): Please refer to other kwargs in
|
||||
[ConversableAgent](../conversable_agent#__init__).
|
||||
"""
|
||||
super().__init__(
|
||||
name,
|
||||
system_message,
|
||||
is_termination_msg=is_termination_msg,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.update_system_message(system_message)
|
||||
self._is_termination_msg = (
|
||||
is_termination_msg
|
||||
if is_termination_msg is not None
|
||||
else (lambda x: any([item["text"] == "TERMINATE" for item in x.get("content") if item["type"] == "text"]))
|
||||
)
|
||||
|
||||
@property
|
||||
def system_message(self) -> List:
|
||||
"""Return the system message."""
|
||||
return self._oai_system_message[0]["content"]
|
||||
|
||||
def update_system_message(self, system_message: Union[Dict, List, str]):
|
||||
"""Update the system message.
|
||||
|
||||
Args:
|
||||
system_message (str): system message for the OpenAIWrapper inference.
|
||||
"""
|
||||
self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
|
||||
self._oai_system_message[0]["role"] = "system"
|
||||
|
||||
@staticmethod
|
||||
def _message_to_dict(message: Union[Dict, List, str]):
|
||||
"""Convert a message to a dictionary.
|
||||
|
||||
The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary.
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return {"content": gpt4v_formatter(message)}
|
||||
if isinstance(message, list):
|
||||
return {"content": message}
|
||||
else:
|
||||
return message
|
||||
|
||||
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
|
||||
# print the message received
|
||||
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
|
||||
if message.get("role") == "function":
|
||||
func_print = f"***** Response from calling function \"{message['name']}\" *****"
|
||||
print(colored(func_print, "green"), flush=True)
|
||||
print(content_str(message["content"]), flush=True)
|
||||
print(colored("*" * len(func_print), "green"), flush=True)
|
||||
else:
|
||||
content = message.get("content")
|
||||
if content is not None:
|
||||
if "context" in message:
|
||||
content = OpenAIWrapper.instantiate(
|
||||
content,
|
||||
message["context"],
|
||||
self.llm_config and self.llm_config.get("allow_format_str_template", False),
|
||||
)
|
||||
print(content_str(content), flush=True)
|
||||
if "function_call" in message:
|
||||
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
|
||||
print(colored(func_print, "green"), flush=True)
|
||||
print(
|
||||
"Arguments: \n",
|
||||
message["function_call"].get("arguments", "(No arguments found)"),
|
||||
flush=True,
|
||||
sep="",
|
||||
)
|
||||
print(colored("*" * len(func_print), "green"), flush=True)
|
||||
print("\n", "-" * 80, flush=True, sep="")
|
|
@ -1,14 +1,15 @@
|
|||
import subprocess
|
||||
import sys
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from typing import List, Dict, Tuple, Optional, Union, Callable
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from hashlib import md5
|
||||
import logging
|
||||
from autogen import oai
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
||||
from hashlib import md5
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from autogen import oai
|
||||
|
||||
try:
|
||||
import docker
|
||||
|
@ -29,6 +30,19 @@ PATH_SEPARATOR = WIN32 and "\\" or "/"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def content_str(content: Union[str, List]) -> str:
|
||||
if type(content) is str:
|
||||
return content
|
||||
rst = ""
|
||||
for item in content:
|
||||
if item["type"] == "text":
|
||||
rst += item["text"]
|
||||
else:
|
||||
assert isinstance(item, dict) and item["type"] == "image_url", "Wrong content format."
|
||||
rst += "<image>"
|
||||
return rst
|
||||
|
||||
|
||||
def infer_lang(code):
|
||||
"""infer the language for the code.
|
||||
TODO: make it robust.
|
||||
|
@ -46,12 +60,13 @@ def infer_lang(code):
|
|||
|
||||
|
||||
def extract_code(
|
||||
text: str, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
|
||||
text: Union[str, List], pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Extract code from a text.
|
||||
|
||||
Args:
|
||||
text (str): The text to extract code from.
|
||||
text (str or List): The content to extract code from. The content can be
|
||||
a string or a list, as returned by standard GPT or multimodal GPT.
|
||||
pattern (str, optional): The regular expression pattern for finding the
|
||||
code block. Defaults to CODE_BLOCK_PATTERN.
|
||||
detect_single_line_code (bool, optional): Enable the new feature for
|
||||
|
@ -62,6 +77,7 @@ def extract_code(
|
|||
If there is no code block in the input text, the language would be "unknown".
|
||||
If there is code block but the language is not specified, the language would be "".
|
||||
"""
|
||||
text = content_str(text)
|
||||
if not detect_single_line_code:
|
||||
match = re.findall(pattern, text, flags=re.DOTALL)
|
||||
return match if match else [(UNKNOWN, text)]
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
import base64
|
||||
import mimetypes
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_image_data(image_file: str, use_b64=True) -> bytes:
|
||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
||||
response = requests.get(image_file)
|
||||
content = response.content
|
||||
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
|
||||
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
|
||||
else:
|
||||
image = Image.open(image_file).convert("RGB")
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
|
||||
if use_b64:
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
else:
|
||||
return content
|
||||
|
||||
|
||||
def llava_formater(prompt: str, order_image_tokens: bool = False) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Formats the input prompt by replacing image tags and returns the new prompt along with image locations.
|
||||
|
||||
Parameters:
|
||||
- prompt (str): The input string that may contain image tags like <img ...>.
|
||||
- order_image_tokens (bool, optional): Whether to order the image tokens with numbers.
|
||||
It will be useful for GPT-4V. Defaults to False.
|
||||
|
||||
Returns:
|
||||
- Tuple[str, List[str]]: A tuple containing the formatted string and a list of images (loaded in b64 format).
|
||||
"""
|
||||
|
||||
# Initialize variables
|
||||
new_prompt = prompt
|
||||
image_locations = []
|
||||
images = []
|
||||
image_count = 0
|
||||
|
||||
# Regular expression pattern for matching <img ...> tags
|
||||
img_tag_pattern = re.compile(r"<img ([^>]+)>")
|
||||
|
||||
# Find all image tags
|
||||
for match in img_tag_pattern.finditer(prompt):
|
||||
image_location = match.group(1)
|
||||
|
||||
try:
|
||||
img_data = get_image_data(image_location)
|
||||
except Exception as e:
|
||||
# Remove the token
|
||||
print(f"Warning! Unable to load image from {image_location}, because of {e}")
|
||||
new_prompt = new_prompt.replace(match.group(0), "", 1)
|
||||
continue
|
||||
|
||||
image_locations.append(image_location)
|
||||
images.append(img_data)
|
||||
|
||||
# Increment the image count and replace the tag in the prompt
|
||||
new_token = f"<image {image_count}>" if order_image_tokens else "<image>"
|
||||
|
||||
new_prompt = new_prompt.replace(match.group(0), new_token, 1)
|
||||
image_count += 1
|
||||
|
||||
return new_prompt, images
|
||||
|
||||
|
||||
def convert_base64_to_data_uri(base64_image):
|
||||
def _get_mime_type_from_data_uri(base64_image):
|
||||
# Decode the base64 string
|
||||
image_data = base64.b64decode(base64_image)
|
||||
# Check the first few bytes for known signatures
|
||||
if image_data.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
|
||||
return "image/gif"
|
||||
elif image_data.startswith(b"RIFF") and image_data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
return "image/jpeg" # use jpeg for unknown formats, best guess.
|
||||
|
||||
mime_type = _get_mime_type_from_data_uri(base64_image)
|
||||
data_uri = f"data:{mime_type};base64,{base64_image}"
|
||||
return data_uri
|
||||
|
||||
|
||||
def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
|
||||
"""
|
||||
Formats the input prompt by replacing image tags and returns a list of text and images.
|
||||
|
||||
Parameters:
|
||||
- prompt (str): The input string that may contain image tags like <img ...>.
|
||||
|
||||
Returns:
|
||||
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
|
||||
"""
|
||||
output = []
|
||||
last_index = 0
|
||||
image_count = 0
|
||||
|
||||
# Regular expression pattern for matching <img ...> tags
|
||||
img_tag_pattern = re.compile(r"<img ([^>]+)>")
|
||||
|
||||
# Find all image tags
|
||||
for match in img_tag_pattern.finditer(prompt):
|
||||
image_location = match.group(1)
|
||||
|
||||
try:
|
||||
img_data = get_image_data(image_location)
|
||||
except Exception as e:
|
||||
# Warning and skip this token
|
||||
print(f"Warning! Unable to load image from {image_location}, because {e}")
|
||||
continue
|
||||
|
||||
# Add text before this image tag to output list
|
||||
output.append({"type": "text", "text": prompt[last_index : match.start()]})
|
||||
|
||||
# Add image data to output list
|
||||
output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
|
||||
|
||||
last_index = match.end()
|
||||
image_count += 1
|
||||
|
||||
# Add remaining text to output list
|
||||
output.append({"type": "text", "text": prompt[last_index:]})
|
||||
return output
|
||||
|
||||
|
||||
def extract_img_paths(paragraph: str) -> list:
|
||||
"""
|
||||
Extract image paths (URLs or local paths) from a text paragraph.
|
||||
|
||||
Parameters:
|
||||
paragraph (str): The input text paragraph.
|
||||
|
||||
Returns:
|
||||
list: A list of extracted image paths.
|
||||
"""
|
||||
# Regular expression to match image URLs and file paths
|
||||
img_path_pattern = re.compile(
|
||||
r"\b(?:http[s]?://\S+\.(?:jpg|jpeg|png|gif|bmp)|\S+\.(?:jpg|jpeg|png|gif|bmp))\b", re.IGNORECASE
|
||||
)
|
||||
|
||||
# Find all matches in the paragraph
|
||||
img_paths = re.findall(img_path_pattern, paragraph)
|
||||
return img_paths
|
||||
|
||||
|
||||
def _to_pil(data: str) -> Image.Image:
|
||||
"""
|
||||
Converts a base64 encoded image data string to a PIL Image object.
|
||||
|
||||
This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes,
|
||||
and finally creates and returns a PIL Image object from the BytesIO object.
|
||||
|
||||
Parameters:
|
||||
data (str): The base64 encoded image data string.
|
||||
|
||||
Returns:
|
||||
Image.Image: The PIL Image object created from the input data.
|
||||
"""
|
||||
return Image.open(BytesIO(base64.b64decode(data)))
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
6
setup.py
6
setup.py
|
@ -1,12 +1,12 @@
|
|||
import setuptools
|
||||
import os
|
||||
|
||||
import setuptools
|
||||
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
with open("README.md", "r", encoding="UTF-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
|
||||
# Get the code version
|
||||
version = {}
|
||||
with open(os.path.join(here, "autogen/version.py")) as fp:
|
||||
|
@ -22,7 +22,6 @@ install_requires = [
|
|||
"tiktoken",
|
||||
]
|
||||
|
||||
|
||||
setuptools.setup(
|
||||
name="pyautogen",
|
||||
version=__version__,
|
||||
|
@ -52,6 +51,7 @@ setuptools.setup(
|
|||
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
|
||||
"retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython"],
|
||||
"teachable": ["chromadb"],
|
||||
"lmm": ["replicate", "pillow"],
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import autogen
|
||||
|
||||
try:
|
||||
from autogen.agentchat.contrib.llava_agent import (
|
||||
LLaVAAgent,
|
||||
_llava_call_binary_with_config,
|
||||
llava_call,
|
||||
llava_call_binary,
|
||||
)
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
skip = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestLLaVAAgent(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.agent = LLaVAAgent(
|
||||
name="TestAgent",
|
||||
llm_config={
|
||||
"timeout": 600,
|
||||
"seed": 42,
|
||||
"config_list": [{"model": "llava-fake", "base_url": "localhost:8000", "api_key": "Fake"}],
|
||||
},
|
||||
)
|
||||
|
||||
def test_init(self):
|
||||
self.assertIsInstance(self.agent, LLaVAAgent)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestLLavaCallBinaryWithConfig(unittest.TestCase):
|
||||
@patch("requests.post")
|
||||
def test_local_mode(self, mock_post):
|
||||
# Mocking the response of requests.post
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [b'{"text":"response text"}']
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Calling the function
|
||||
output = _llava_call_binary_with_config(
|
||||
prompt="Test Prompt",
|
||||
images=[],
|
||||
config={"base_url": "http://0.0.0.0/api", "model": "test-model"},
|
||||
max_new_tokens=1000,
|
||||
temperature=0.5,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
# Verifying the results
|
||||
self.assertEqual(output, "response text")
|
||||
mock_post.assert_called_once_with(
|
||||
"http://0.0.0.0/api/worker_generate_stream",
|
||||
headers={"User-Agent": "LLaVA Client"},
|
||||
json={
|
||||
"model": "test-model",
|
||||
"prompt": "Test Prompt",
|
||||
"max_new_tokens": 1000,
|
||||
"temperature": 0.5,
|
||||
"stop": "###",
|
||||
"images": [],
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
@patch("replicate.run")
|
||||
def test_remote_mode(self, mock_run):
|
||||
# Mocking the response of replicate.run
|
||||
mock_run.return_value = iter(["response ", "text"])
|
||||
|
||||
# Calling the function
|
||||
output = _llava_call_binary_with_config(
|
||||
prompt="Test Prompt",
|
||||
images=["image_data"],
|
||||
config={"base_url": "http://remote/api", "model": "test-model"},
|
||||
max_new_tokens=1000,
|
||||
temperature=0.5,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
# Verifying the results
|
||||
self.assertEqual(output, "response text")
|
||||
mock_run.assert_called_once_with(
|
||||
"http://remote/api",
|
||||
input={"image": "_data", "prompt": "Test Prompt", "seed": 1},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestLLavaCall(unittest.TestCase):
|
||||
@patch("autogen.agentchat.contrib.llava_agent.llava_formater")
|
||||
@patch("autogen.agentchat.contrib.llava_agent.llava_call_binary")
|
||||
def test_llava_call(self, mock_llava_call_binary, mock_llava_formater):
|
||||
# Set up the mocks
|
||||
mock_llava_formater.return_value = ("formatted prompt", ["image1", "image2"])
|
||||
mock_llava_call_binary.return_value = "Generated Text"
|
||||
|
||||
# Set up the llm_config dictionary
|
||||
llm_config = {
|
||||
"config_list": [{"api_key": "value", "base_url": "localhost:8000"}],
|
||||
"max_new_tokens": 2000,
|
||||
"temperature": 0.5,
|
||||
"seed": 1,
|
||||
}
|
||||
|
||||
# Call the function
|
||||
result = llava_call("Test Prompt", llm_config)
|
||||
|
||||
# Check the results
|
||||
mock_llava_formater.assert_called_once_with("Test Prompt", order_image_tokens=False)
|
||||
mock_llava_call_binary.assert_called_once_with(
|
||||
"formatted prompt",
|
||||
["image1", "image2"],
|
||||
config_list=llm_config["config_list"],
|
||||
max_new_tokens=2000,
|
||||
temperature=0.5,
|
||||
seed=1,
|
||||
)
|
||||
self.assertEqual(result, "Generated Text")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,83 @@
|
|||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import autogen
|
||||
from autogen.agentchat.agent import Agent
|
||||
|
||||
try:
|
||||
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
skip = False
|
||||
|
||||
|
||||
base64_encoded_image = (
|
||||
""
|
||||
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestMultimodalConversableAgent(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.agent = MultimodalConversableAgent(
|
||||
name="TestAgent",
|
||||
llm_config={
|
||||
"timeout": 600,
|
||||
"seed": 42,
|
||||
"config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-fake"}],
|
||||
},
|
||||
)
|
||||
|
||||
def test_system_message(self):
|
||||
# Test default system message
|
||||
self.assertEqual(
|
||||
self.agent.system_message,
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are a helpful AI assistant.",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Test updating system message
|
||||
new_message = f"We will discuss <img {base64_encoded_image}> in this conversation."
|
||||
self.agent.update_system_message(new_message)
|
||||
self.assertEqual(
|
||||
self.agent.system_message,
|
||||
[
|
||||
{"type": "text", "text": "We will discuss "},
|
||||
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
|
||||
{"type": "text", "text": " in this conversation."},
|
||||
],
|
||||
)
|
||||
|
||||
def test_message_to_dict(self):
|
||||
# Test string message
|
||||
message_str = "Hello"
|
||||
expected_dict = {"content": [{"type": "text", "text": "Hello"}]}
|
||||
self.assertDictEqual(self.agent._message_to_dict(message_str), expected_dict)
|
||||
|
||||
# Test list message
|
||||
message_list = [{"type": "text", "text": "Hello"}]
|
||||
expected_dict = {"content": message_list}
|
||||
self.assertDictEqual(self.agent._message_to_dict(message_list), expected_dict)
|
||||
|
||||
# Test dictionary message
|
||||
message_dict = {"content": [{"type": "text", "text": "Hello"}]}
|
||||
self.assertDictEqual(self.agent._message_to_dict(message_dict), message_dict)
|
||||
|
||||
def test_print_received_message(self):
|
||||
sender = Agent(name="SenderAgent")
|
||||
message_str = "Hello"
|
||||
self.agent._print_received_message = MagicMock() # Mocking print method to avoid actual print
|
||||
self.agent._print_received_message(message_str, sender)
|
||||
self.agent._print_received_message.assert_called_with(message_str, sender)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,16 +1,20 @@
|
|||
import sys
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
import autogen
|
||||
from autogen.code_utils import (
|
||||
PATH_SEPARATOR,
|
||||
UNKNOWN,
|
||||
extract_code,
|
||||
WIN32,
|
||||
content_str,
|
||||
execute_code,
|
||||
infer_lang,
|
||||
extract_code,
|
||||
improve_code,
|
||||
improve_function,
|
||||
PATH_SEPARATOR,
|
||||
WIN32,
|
||||
infer_lang,
|
||||
)
|
||||
|
||||
KEY_LOC = "notebook"
|
||||
|
@ -315,8 +319,36 @@ def _test_improve():
|
|||
f.write(improvement)
|
||||
|
||||
|
||||
class TestContentStr(unittest.TestCase):
|
||||
def test_string_content(self):
|
||||
self.assertEqual(content_str("simple string"), "simple string")
|
||||
|
||||
def test_list_of_text_content(self):
|
||||
content = [{"type": "text", "text": "hello"}, {"type": "text", "text": " world"}]
|
||||
self.assertEqual(content_str(content), "hello world")
|
||||
|
||||
def test_mixed_content(self):
|
||||
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "url": "http://example.com/image.png"}]
|
||||
self.assertEqual(content_str(content), "hello<image>")
|
||||
|
||||
def test_invalid_content(self):
|
||||
content = [{"type": "text", "text": "hello"}, {"type": "wrong_type", "url": "http://example.com/image.png"}]
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
content_str(content)
|
||||
self.assertIn("Wrong content format", str(context.exception))
|
||||
|
||||
def test_empty_list(self):
|
||||
self.assertEqual(content_str([]), "")
|
||||
|
||||
def test_non_dict_in_list(self):
|
||||
content = ["string", {"type": "text", "text": "text"}]
|
||||
with self.assertRaises(TypeError):
|
||||
content_str(content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_infer_lang()
|
||||
# test_extract_code()
|
||||
test_execute_code()
|
||||
# test_find_code()
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,193 @@
|
|||
import base64
|
||||
import os
|
||||
import pdb
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
from autogen.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formater
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
skip = False
|
||||
|
||||
|
||||
base64_encoded_image = (
|
||||
""
|
||||
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
raw_encoded_image = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAUAAAAFCAYAAACNbyblAAAAHElEQVQI12P4"
|
||||
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestGetImageData(unittest.TestCase):
|
||||
def test_http_image(self):
|
||||
with patch("requests.get") as mock_get:
|
||||
mock_response = requests.Response()
|
||||
mock_response.status_code = 200
|
||||
mock_response._content = b"fake image content"
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = get_image_data("http://example.com/image.png")
|
||||
self.assertEqual(result, base64.b64encode(b"fake image content").decode("utf-8"))
|
||||
|
||||
def test_base64_encoded_image(self):
|
||||
result = get_image_data(base64_encoded_image)
|
||||
self.assertEqual(result, base64_encoded_image.split(",", 1)[1])
|
||||
|
||||
def test_local_image(self):
|
||||
# Create a temporary file to simulate a local image file.
|
||||
temp_file = "_temp.png"
|
||||
|
||||
image = Image.new("RGB", (60, 30), color=(73, 109, 137))
|
||||
image.save(temp_file)
|
||||
|
||||
result = get_image_data(temp_file)
|
||||
with open(temp_file, "rb") as temp_image_file:
|
||||
temp_image_file.seek(0)
|
||||
expected_content = base64.b64encode(temp_image_file.read()).decode("utf-8")
|
||||
|
||||
self.assertEqual(result, expected_content)
|
||||
os.remove(temp_file)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestLlavaFormater(unittest.TestCase):
|
||||
def test_no_images(self):
|
||||
"""
|
||||
Test the llava_formater function with a prompt containing no images.
|
||||
"""
|
||||
prompt = "This is a test."
|
||||
expected_output = (prompt, [])
|
||||
result = llava_formater(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("autogen.img_utils.get_image_data")
|
||||
def test_with_images(self, mock_get_image_data):
|
||||
"""
|
||||
Test the llava_formater function with a prompt containing images.
|
||||
"""
|
||||
# Mock the get_image_data function to return a fixed string.
|
||||
mock_get_image_data.return_value = raw_encoded_image
|
||||
|
||||
prompt = "This is a test with an image <img http://example.com/image.png>."
|
||||
expected_output = ("This is a test with an image <image>.", [raw_encoded_image])
|
||||
result = llava_formater(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("autogen.img_utils.get_image_data")
|
||||
def test_with_ordered_images(self, mock_get_image_data):
|
||||
"""
|
||||
Test the llava_formater function with ordered image tokens.
|
||||
"""
|
||||
# Mock the get_image_data function to return a fixed string.
|
||||
mock_get_image_data.return_value = raw_encoded_image
|
||||
|
||||
prompt = "This is a test with an image <img http://example.com/image.png>."
|
||||
expected_output = ("This is a test with an image <image 0>.", [raw_encoded_image])
|
||||
result = llava_formater(prompt, order_image_tokens=True)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestGpt4vFormatter(unittest.TestCase):
|
||||
def test_no_images(self):
|
||||
"""
|
||||
Test the gpt4v_formatter function with a prompt containing no images.
|
||||
"""
|
||||
prompt = "This is a test."
|
||||
expected_output = [{"type": "text", "text": prompt}]
|
||||
result = gpt4v_formatter(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("autogen.img_utils.get_image_data")
|
||||
def test_with_images(self, mock_get_image_data):
|
||||
"""
|
||||
Test the gpt4v_formatter function with a prompt containing images.
|
||||
"""
|
||||
# Mock the get_image_data function to return a fixed string.
|
||||
mock_get_image_data.return_value = raw_encoded_image
|
||||
|
||||
prompt = "This is a test with an image <img http://example.com/image.png>."
|
||||
expected_output = [
|
||||
{"type": "text", "text": "This is a test with an image "},
|
||||
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
|
||||
{"type": "text", "text": "."},
|
||||
]
|
||||
result = gpt4v_formatter(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
@patch("autogen.img_utils.get_image_data")
|
||||
def test_multiple_images(self, mock_get_image_data):
|
||||
"""
|
||||
Test the gpt4v_formatter function with a prompt containing multiple images.
|
||||
"""
|
||||
# Mock the get_image_data function to return a fixed string.
|
||||
mock_get_image_data.return_value = raw_encoded_image
|
||||
|
||||
prompt = (
|
||||
"This is a test with images <img http://example.com/image1.png> and <img http://example.com/image2.png>."
|
||||
)
|
||||
expected_output = [
|
||||
{"type": "text", "text": "This is a test with images "},
|
||||
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
|
||||
{"type": "text", "text": " and "},
|
||||
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
|
||||
{"type": "text", "text": "."},
|
||||
]
|
||||
result = gpt4v_formatter(prompt)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
class TestExtractImgPaths(unittest.TestCase):
|
||||
def test_no_images(self):
|
||||
"""
|
||||
Test the extract_img_paths function with a paragraph containing no images.
|
||||
"""
|
||||
paragraph = "This is a test paragraph with no images."
|
||||
expected_output = []
|
||||
result = extract_img_paths(paragraph)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
def test_with_images(self):
|
||||
"""
|
||||
Test the extract_img_paths function with a paragraph containing images.
|
||||
"""
|
||||
paragraph = (
|
||||
"This is a test paragraph with images http://example.com/image1.jpg and http://example.com/image2.png."
|
||||
)
|
||||
expected_output = ["http://example.com/image1.jpg", "http://example.com/image2.png"]
|
||||
result = extract_img_paths(paragraph)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
def test_mixed_case(self):
|
||||
"""
|
||||
Test the extract_img_paths function with mixed case image extensions.
|
||||
"""
|
||||
paragraph = "Mixed case extensions http://example.com/image.JPG and http://example.com/image.Png."
|
||||
expected_output = ["http://example.com/image.JPG", "http://example.com/image.Png"]
|
||||
result = extract_img_paths(paragraph)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
def test_local_paths(self):
|
||||
"""
|
||||
Test the extract_img_paths function with local file paths.
|
||||
"""
|
||||
paragraph = "Local paths image1.jpeg and image2.GIF."
|
||||
expected_output = ["image1.jpeg", "image2.GIF"]
|
||||
result = extract_img_paths(paragraph)
|
||||
self.assertEqual(result, expected_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Binary file not shown.
After Width: | Height: | Size: 2.3 MiB |
|
@ -0,0 +1,77 @@
|
|||
---
|
||||
title: Multimodal with GPT-4V and LLaVA
|
||||
authors: beibinli
|
||||
tags: [LMM, multimodal]
|
||||
---
|
||||
|
||||
![LMM Teaser](img/teaser.png)
|
||||
|
||||
**In Brief:**
|
||||
* Introducing the **Multimodal Conversable Agent** and the **LLaVA Agent** to enhance LMM functionalities.
|
||||
* Users can input text and images simultaneously using the `<img img_path>` tag to specify image loading.
|
||||
* Demonstrated through the [GPT-4V notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_gpt-4v.ipynb).
|
||||
* Demonstrated through the [LLaVA notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb).
|
||||
|
||||
## Introduction
|
||||
Large multimodal models (LMMs) augment large language models (LLMs) with the ability to process multi-sensory data.
|
||||
|
||||
This blog post and the latest AutoGen update concentrate on visual comprehension. Users can input images, pose questions about them, and receive text-based responses from these LMMs.
|
||||
We support the `gpt-4-vision-preview` model from OpenAI and `LLaVA` model from Microsoft now.
|
||||
|
||||
Here, we emphasize the **Multimodal Conversable Agent** and the **LLaVA Agent** due to their growing popularity.
|
||||
GPT-4V represents the forefront in image comprehension, while LLaVA is an efficient model, fine-tuned from LLama-2.
|
||||
|
||||
## Installation
|
||||
Incorporate the `lmm` feature during AutoGen installation:
|
||||
|
||||
```bash
|
||||
pip install "pyautogen[lmm]"
|
||||
```
|
||||
|
||||
Subsequently, import the **Multimodal Conversable Agent** or **LLaVA Agent** from AutoGen:
|
||||
|
||||
```python
|
||||
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent # for GPT-4V
|
||||
from autogen.agentchat.contrib.llava_agent import LLaVAAgent # for LLaVA
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
A simple syntax has been defined to incorporate both messages and images within a single string.
|
||||
|
||||
Example of an in-context learning prompt:
|
||||
|
||||
```python
|
||||
prompt = """You are now an image classifier for facial expressions. Here are
|
||||
some examples.
|
||||
|
||||
<img happy.jpg> depicts a happy expression.
|
||||
<img http://some_location.com/sad.jpg> represents a sad expression.
|
||||
<img obama.jpg> portrays a neutral expression.
|
||||
|
||||
Now, identify the facial expression of this individual: <img unknown.png>
|
||||
"""
|
||||
|
||||
agent = MultimodalConversableAgent()
|
||||
user = UserProxyAgent()
|
||||
user.initiate_chat(agent, message=prompt)
|
||||
```
|
||||
|
||||
The `MultimodalConversableAgent` interprets the input prompt, extracting images from local or internet sources.
|
||||
|
||||
## Advanced Usage
|
||||
Similar to other AutoGen agents, multimodal agents support multi-round dialogues with other agents, code generation, factual queries, and management via a GroupChat interface.
|
||||
|
||||
For example, the `FigureCreator` in our [GPT-4V notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_gpt-4v.ipynb) and [LLaVA notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb) integrates two agents: a coder (an AssistantAgent) and critics (a multimodal agent).
|
||||
The coder drafts Python code for visualizations, while the critics provide insights for enhancement. Collaboratively, these agents aim to refine visual outputs.
|
||||
With `human_input_mode=ALWAYS`, you can also contribute suggestions for better visualizations.
|
||||
|
||||
## Reference
|
||||
- [GPT-4V System Card](https://openai.com/research/gpt-4v-system-card)
|
||||
- [LLaVA GitHub](https://github.com/haotian-liu/LLaVA)
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
For further inquiries or suggestions, please open an issue in the [AutoGen repository](https://github.com/microsoft/autogen/) or contact me directly at beibin.li@microsoft.com.
|
||||
|
||||
AutoGen will continue to evolve, incorporating more multimodal functionalities such as DALLE model integration, audio interaction, and video comprehension. Stay tuned for these exciting developments.
|
|
@ -33,3 +33,9 @@ rickyloynd-microsoft:
|
|||
title: Senior Research Engineer at Microsoft
|
||||
url: https://github.com/rickyloynd-microsoft
|
||||
image_url: https://github.com/rickyloynd-microsoft.png
|
||||
|
||||
beibinli:
|
||||
name: Beibin Li
|
||||
title: Senior Research Engineer at Microsoft
|
||||
url: https://github.com/beibinli
|
||||
image_url: https://github.com/beibinli.png
|
||||
|
|
|
@ -115,6 +115,18 @@ Example notebooks:
|
|||
[Automated Code Generation and Question Answering with Qdrant based Retrieval Augmented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_qdrant_RetrieveChat.ipynb)
|
||||
|
||||
|
||||
- #### Large Multimodal Model (LMM) Agents
|
||||
|
||||
We offered Multimodal Conversable Agent and LLaVA Agent. Please install with the [lmm] option to use it.
|
||||
```bash
|
||||
pip install "pyautogen[lmm]"
|
||||
```
|
||||
|
||||
Example notebooks:
|
||||
|
||||
[LLaVA Agent](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb)
|
||||
|
||||
|
||||
- #### mathchat
|
||||
|
||||
`pyautogen<0.2` offers an experimental agent for math problem solving. Please install with the [mathchat] option to use it.
|
||||
|
|
Loading…
Reference in New Issue