[Feature] Adds Image Generation Capability 2.0 (#1907)

* adds image generation capability

* add todo

* readded cache

* wip

* fix content str bugs

* removed todo: delete imshow

* wip

* fix circular imports

* add notebook

* improve prompt

* improved text analyzer + notebook

* notebook update

* improve notebook

* smaller notebook size

* made changes to the wrong branch :(

* resolve comments + 1

* adds doc strings

* adds cache doc string

* adds doc string to add_to_agent

* adds doc string to ImageGeneration

* instructions are not configurable

* removed unnecessary imports

* changed doc string location

* more doc strings

* improves testability

* adds tests

* adds cache test

* added test to github workflow

* compatible llm config format

* configurable reply function position

* skip_openai + better comments

* fix test

* fix test?

* please fix test?

* last fix test?

* remove type hint

* skip cache test

* adds mock api key

* dalle-2 test

* fix dalle config

* use api key function

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Author: Wael Karkoub, 2024-03-15 23:11:53 +01:00 (committed by GitHub)
Parent: ea2c1b270e
Commit: c5536ee92b
5 changed files with 895 additions and 3 deletions
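For orientation, the sketch below pieces together how the new capability is meant to be exercised, based on the docstring example and the tests in this diff. The model names, API keys, and config dicts are placeholders, not values from this PR.

```python
import autogen
from autogen.agentchat.contrib.capabilities import generate_images
from autogen.agentchat.contrib.img_utils import get_pil_image

# Placeholder configs: substitute your own models and API keys.
gpt_config = {"config_list": [{"model": "gpt-4", "api_key": "sk-..."}]}
dalle_config = {"config_list": [{"model": "dall-e-3", "api_key": "sk-..."}]}

# The agent that gains the image-generation ability.
agent = autogen.ConversableAgent(
    name="dalle", llm_config=gpt_config, max_consecutive_auto_reply=3, human_input_mode="NEVER"
)
dalle_gen = generate_images.DalleImageGenerator(llm_config=dalle_config)
generate_images.ImageGeneration(image_generator=dalle_gen).add_to_agent(agent)

# Ask for an image and pull it out of the multimodal reply produced by the registered reply function.
user = autogen.UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)
user.send(
    message="Generate an image of a robot holding an 'I Love Autogen' sign",
    recipient=agent,
    request_reply=True,
)
reply = agent.last_message()
image_parts = [part for part in reply["content"] if part["type"] == "image_url"]
image = get_pil_image(image_parts[0]["image_url"]["url"])  # PIL.Image.Image
```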


@@ -299,3 +299,38 @@ jobs:
        with:
          file: ./coverage.xml
          flags: unittests
  ImageGen:
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.12"]
    runs-on: ${{ matrix.os }}
    environment: openai1
    steps:
      # checkout to pr branch
      - name: Checkout
        uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies
        run: |
          docker --version
          python -m pip install --upgrade pip wheel
          pip install -e .[lmm]
          python -c "import autogen"
          pip install coverage pytest
      - name: Coverage
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: |
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests


@@ -15,8 +15,9 @@ on:
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-permissions: {}
-  # actions: read
+permissions:
+  {}
+  # actions: read
  # checks: read
  # contents: read
  # deployments: read
@@ -246,7 +247,7 @@ jobs:
      - name: Coverage
        run: |
          pip install coverage>=5.3
-          coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai
+          coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3


@@ -0,0 +1,291 @@
import re
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union

from openai import OpenAI
from PIL.Image import Image

from autogen import Agent, ConversableAgent, code_utils
from autogen.cache import Cache
from autogen.agentchat.contrib import img_utils
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent

SYSTEM_MESSAGE = "You've been given the special ability to generate images."
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."

PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
DO NOT include any advice. RESPOND like the following example:
EXAMPLE: Blue background, 3D shapes, ...
"""


class ImageGenerator(Protocol):
    """This class defines an interface for image generators.

    Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as
    input and returns a PIL Image object.

    NOTE: Current implementation does not allow you to edit a previously existing image.
    """

    def generate_image(self, prompt: str) -> Image:
        """Generates an image based on the provided prompt.

        Args:
            prompt: A string describing the desired image.

        Returns:
            A PIL Image object representing the generated image.

        Raises:
            ValueError: If the image generation fails.
        """
        ...

    def cache_key(self, prompt: str) -> str:
        """Generates a unique cache key for the given prompt.

        This key can be used to store and retrieve generated images based on the prompt.

        Args:
            prompt: A string describing the desired image.

        Returns:
            A unique string that can be used as a cache key.
        """
        ...


class DalleImageGenerator:
    """Generates images using OpenAI's DALL-E models.

    This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E
    models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate.

    Note: Current implementation does not allow you to edit a previously existing image.
    """

    def __init__(
        self,
        llm_config: Dict,
        resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        quality: Literal["standard", "hd"] = "standard",
        num_images: int = 1,
    ):
        """
        Args:
            llm_config (dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
            resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
            quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
            num_images (int): The number of images to generate.
        """
        config_list = llm_config["config_list"]
        _validate_dalle_model(config_list[0]["model"])
        _validate_resolution_format(resolution)

        self._model = config_list[0]["model"]
        self._resolution = resolution
        self._quality = quality
        self._num_images = num_images

        self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])

    def generate_image(self, prompt: str) -> Image:
        response = self._dalle_client.images.generate(
            model=self._model,
            prompt=prompt,
            size=self._resolution,
            quality=self._quality,
            n=self._num_images,
        )

        image_url = response.data[0].url
        if image_url is None:
            raise ValueError("Failed to generate image.")

        return img_utils.get_pil_image(image_url)

    def cache_key(self, prompt: str) -> str:
        keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
        return ",".join([str(k) for k in keys])


class ImageGeneration(AgentCapability):
    """This capability allows a ConversableAgent to generate images based on the message received from other Agents.

    1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
       extract relevant details.
    2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
    3. Optionally caches generated images for faster retrieval in future conversations.

    NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
       message received by the agent.

    Example:
        ```python
        import autogen
        from autogen.agentchat.contrib.capabilities import generate_images

        # Assuming you have llm configs configured for the LLMs you want to use and Dalle.
        # Create the agent
        agent = autogen.ConversableAgent(
            name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
        )

        # Create an ImageGenerator with desired settings
        dalle_gen = generate_images.DalleImageGenerator(llm_config={...})

        # Add the ImageGeneration capability to the agent
        generate_images.ImageGeneration(image_generator=dalle_gen).add_to_agent(agent)
        ```
"""
def __init__(
self,
image_generator: ImageGenerator,
cache: Optional[Cache] = None,
text_analyzer_llm_config: Optional[Dict] = None,
text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
verbosity: int = 0,
register_reply_position: int = 2,
):
"""
Args:
image_generator (ImageGenerator): The image generator you would like to use to generate images.
cache (None or Cache): The cache client to use to store and retrieve generated images. If None,
no caching will be used.
text_analyzer_llm_config (Dict or None): The LLM config for the text analyzer. If None, the LLM config will
be retrieved from the agent you're adding the ability to.
text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
incoming messages and extract the prompt for image generation. The default instructions focus on
summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
extraction.
Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
analyzer llm calls will be silent if verbosity is less than 2.
register_reply_position (int): The position of the reply function in the agent's list of reply functions.
This capability registers a new reply function to handle messages with image generation requests.
Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
"""
self._image_generator = image_generator
self._cache = cache
self._text_analyzer_llm_config = text_analyzer_llm_config
self._text_analyzer_instructions = text_analyzer_instructions
self._verbosity = verbosity
self._register_reply_position = register_reply_position
self._agent: Optional[ConversableAgent] = None
self._text_analyzer: Optional[TextAnalyzerAgent] = None
def add_to_agent(self, agent: ConversableAgent):
"""Adds the Image Generation capability to the specified ConversableAgent.
This function performs the following modifications to the agent:
1. Registers a reply function: A new reply function is registered with the agent to handle messages that
potentially request image generation. This function analyzes the message and triggers image generation if
necessary.
2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
3. Updates System Message: The agent's system message is updated to include a message indicating the
capability to generate images has been added.
4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
capability. This might be helpful in certain use cases, like group chats.
Args:
agent (ConversableAgent): The ConversableAgent to add the capability to.
"""
self._agent = agent
agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)
self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)
agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
agent.description += "\n" + DESCRIPTION_MESSAGE
def _image_gen_reply(
self,
recipient: ConversableAgent,
messages: Optional[List[Dict]],
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
if messages is None:
return False, None
last_message = code_utils.content_str(messages[-1]["content"])
if not last_message:
return False, None
if self._should_generate_image(last_message):
prompt = self._extract_prompt(last_message)
image = self._cache_get(prompt)
if image is None:
image = self._image_generator.generate_image(prompt)
self._cache_set(prompt, image)
return True, self._generate_content_message(prompt, image)
else:
return False, None
def _should_generate_image(self, message: str) -> bool:
assert self._text_analyzer is not None
instructions = """
Does any part of the TEXT ask the agent to generate an image?
The TEXT must explicitly mention that the image must be generated.
Answer with just one word, yes or no.
"""
analysis = self._text_analyzer.analyze_text(message, instructions)
return "yes" in self._extract_analysis(analysis).lower()
def _extract_prompt(self, last_message) -> str:
assert self._text_analyzer is not None
analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
return self._extract_analysis(analysis)
def _cache_get(self, prompt: str) -> Optional[Image]:
if self._cache:
key = self._image_generator.cache_key(prompt)
cached_value = self._cache.get(key)
if cached_value:
return img_utils.get_pil_image(cached_value)
def _cache_set(self, prompt: str, image: Image):
if self._cache:
key = self._image_generator.cache_key(prompt)
self._cache.set(key, img_utils.pil_to_data_uri(image))
def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str:
if isinstance(analysis, Dict):
return code_utils.content_str(analysis["content"])
else:
return code_utils.content_str(analysis)
def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]:
return {
"content": [
{"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
{"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
]
}
### Helpers
def _validate_resolution_format(resolution: str):
"""Checks if a string is in a valid resolution format (e.g., "1024x768")."""
pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits
matched_resolution = re.match(pattern, resolution)
if matched_resolution is None:
raise ValueError(f"Invalid resolution format: {resolution}")
def _validate_dalle_model(model: str):
if model not in ["dall-e-3", "dall-e-2"]:
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")

File diff suppressed because one or more lines are too long


@@ -0,0 +1,224 @@
import itertools
import os
import tempfile
from typing import Any, Dict, Tuple

import pytest
import sys

from autogen import code_utils
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.agentchat.user_proxy_agent import UserProxyAgent
from autogen.cache.cache import Cache
from autogen.oai import openai_utils

try:
    from PIL import Image

    from autogen.agentchat.contrib.capabilities import generate_images
    from autogen.agentchat.contrib.img_utils import get_pil_image
except ImportError:
    skip_requirement = True
else:
    skip_requirement = False

sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai  # noqa: E402

filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}

RESOLUTIONS = ["256x256", "512x512", "1024x1024"]
QUALITIES = ["standard", "hd"]
PROMPTS = [
    "Generate an image of a robot holding a 'I Love Autogen' sign",
    "Generate an image of a dog holding a 'I Love Autogen' sign",
]


class _TestImageGenerator:
    def __init__(self, image):
        self._image = image

    def generate_image(self, prompt: str):
        return self._image

    def cache_key(self, prompt: str):
        return prompt


def create_test_agent(name: str = "test_agent", default_auto_reply: str = "") -> ConversableAgent:
    return ConversableAgent(name=name, llm_config=False, default_auto_reply=default_auto_reply)


def dalle_image_generator(dalle_config: Dict[str, Any], resolution: str, quality: str):
    return generate_images.DalleImageGenerator(dalle_config, resolution=resolution, quality=quality, num_images=1)


def api_key():
    return MOCK_OPEN_AI_API_KEY if skip_openai else os.environ.get("OPENAI_API_KEY")


@pytest.fixture
def dalle_config() -> Dict[str, Any]:
    config_list = openai_utils.config_list_from_models(model_list=["dall-e-2"], exclude="aoai")
    if not config_list:
        config_list = [{"model": "dall-e-2", "api_key": api_key()}]
    return {"config_list": config_list, "timeout": 120, "cache_seed": None}


@pytest.fixture
def gpt3_config() -> Dict[str, Any]:
    config_list = [
        {
            "model": "gpt-35-turbo-16k",
            "api_key": api_key(),
        },
        {
            "model": "gpt-3.5-turbo-16k",
            "api_key": api_key(),
        },
    ]
    return {"config_list": config_list, "timeout": 120, "cache_seed": None}


@pytest.fixture
def image_gen_capability():
    image_generator = _TestImageGenerator(Image.new("RGB", (256, 256)))
    return generate_images.ImageGeneration(image_generator)


@pytest.mark.skipif(skip_openai, reason="Requested to skip.")
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_dalle_image_generator(dalle_config: Dict[str, Any]):
    """Tests DalleImageGenerator capability to generate images by calling the OpenAI API."""
    dalle_generator = dalle_image_generator(dalle_config, RESOLUTIONS[0], QUALITIES[0])
    image = dalle_generator.generate_image(PROMPTS[0])

    assert isinstance(image, Image.Image)


# Using cartesian product to generate all possible combinations of resolution, quality, and prompt
@pytest.mark.parametrize("gen_config_1", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS))
@pytest.mark.parametrize("gen_config_2", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS))
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_dalle_image_generator_cache_key(
    dalle_config: Dict[str, Any], gen_config_1: Tuple[str, str, str], gen_config_2: Tuple[str, str, str]
):
    """Tests if DalleImageGenerator creates unique cache keys.

    Args:
        dalle_config: The LLM config for the DalleImageGenerator.
        gen_config_1: A tuple containing the resolution, quality, and prompt for the first image generator.
        gen_config_2: A tuple containing the resolution, quality, and prompt for the second image generator.
    """
    dalle_generator_1 = dalle_image_generator(dalle_config, resolution=gen_config_1[0], quality=gen_config_1[1])
    dalle_generator_2 = dalle_image_generator(dalle_config, resolution=gen_config_2[0], quality=gen_config_2[1])

    cache_key_1 = dalle_generator_1.cache_key(gen_config_1[2])
    cache_key_2 = dalle_generator_2.cache_key(gen_config_2[2])

    if gen_config_1 == gen_config_2:
        assert cache_key_1 == cache_key_2
    else:
        assert cache_key_1 != cache_key_2


@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_image_generation_capability_positive(monkeypatch, image_gen_capability):
    """Tests the ImageGeneration capability end to end with a stub ImageGenerator.

    This test covers the case where the incoming message does ask the agent to generate an image.
    """
    auto_reply = "Didn't need to generate an image."

    # Patch _should_generate_image and _extract_prompt so the TextAnalyzerAgent never makes API calls.
    # This improves reproducibility and reduces flakiness of the test.
    monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: True)
    monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])

    user = UserProxyAgent("user", human_input_mode="NEVER")
    agent = create_test_agent(default_auto_reply=auto_reply)
    image_gen_capability.add_to_agent(agent)

    user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)
    last_message = agent.last_message()

    assert last_message

    processed_message = code_utils.content_str(last_message["content"])

    assert "<image>" in processed_message
    assert auto_reply not in processed_message


@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_image_generation_capability_negative(monkeypatch, image_gen_capability):
    """Tests the ImageGeneration capability end to end with a stub ImageGenerator.

    This test covers the case where the incoming message does not ask the agent to generate an image.
    """
    auto_reply = "Didn't need to generate an image."

    # Patch _should_generate_image and _extract_prompt so the TextAnalyzerAgent never makes API calls.
    # This improves reproducibility and reduces flakiness of the test.
    monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: False)
    monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])

    user = UserProxyAgent("user", human_input_mode="NEVER")
    agent = ConversableAgent("test_agent", llm_config=False, default_auto_reply=auto_reply)
    image_gen_capability.add_to_agent(agent)

    user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)
    last_message = agent.last_message()

    assert last_message

    processed_message = code_utils.content_str(last_message["content"])

    assert "<image>" not in processed_message
    assert auto_reply == processed_message


@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_image_generation_capability_cache(monkeypatch):
    """Tests ImageGeneration capability to cache the generated images."""
    test_image_size = (256, 256)

    # Patching the _should_generate_image and _extract_prompt methods to avoid TextAnalyzerAgent making API calls.
    monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: True)
    monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])

    with tempfile.TemporaryDirectory() as temp_dir:
        cache = Cache.disk(cache_path_root=temp_dir)

        user = UserProxyAgent("user", human_input_mode="NEVER")
        agent = create_test_agent()

        test_image_generator = _TestImageGenerator(Image.new("RGB", test_image_size))
        image_gen_capability = generate_images.ImageGeneration(test_image_generator, cache=cache)
        image_gen_capability.add_to_agent(agent)

        user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)

        # Checking if the image has been cached by creating a new agent with a different image generator.
        agent = create_test_agent(name="test_agent_2")
        test_image_generator = _TestImageGenerator(Image.new("RGB", (512, 512)))
        image_gen_capability = generate_images.ImageGeneration(test_image_generator, cache=cache)
        image_gen_capability.add_to_agent(agent)

        user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)

        last_message = agent.last_message()
        assert last_message

        image_dict = [image for image in last_message["content"] if image["type"] == "image_url"]
        image = get_pil_image(image_dict[0]["image_url"]["url"])

        assert image.size == test_image_size


if __name__ == "__main__":
    test_dalle_image_generator(
        dalle_config={"config_list": openai_utils.config_list_from_models(model_list=["dall-e-2"], exclude="aoai")}
    )