mirror of https://github.com/microsoft/autogen.git
[Feature] Adds Image Generation Capability 2.0 (#1907)
* adds image generation capability * add todo * readded cache * wip * fix content str bugs * removed todo: delete imshow * wip * fix circular imports * add notebook * improve prompt * improved text analyzer + notebook * notebook update * improve notebook * smaller notebook size * made changes to the wrong branch :( * resolve comments + 1 * adds doc strings * adds cache doc string * adds doc string to add_to_agent * adds doc string to ImageGeneration * instructions are not configurable * removed unnecessary imports * changed doc string location * more doc strings * improves testability * adds tests * adds cache test * added test to github workflow * compatible llm config format * configurable reply function position * skip_openai + better comments * fix test * fix test? * please fix test? * last fix test? * remove type hint * skip cache test * adds mock api key * dalle-2 test * fix dalle config * use api key function --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
parent
ea2c1b270e
commit
c5536ee92b
|
@ -299,3 +299,38 @@ jobs:
|
|||
with:
|
||||
file: ./coverage.xml
|
||||
flags: unittests
|
||||
# Job that runs the image generation capability tests with OPENAI_API_KEY
# supplied from the repository secrets (pytest is invoked without --skip-openai).
ImageGen:
  strategy:
    matrix:
      os: [ubuntu-latest]
      python-version: ["3.12"]
  runs-on: ${{ matrix.os }}
  environment: openai1
  steps:
    # checkout to pr branch
    - name: Checkout
      uses: actions/checkout@v3
      with:
        ref: ${{ github.event.pull_request.head.sha }}
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    # Installs the package with the `lmm` extra and checks that it imports
    # before the test tooling is added.
    - name: Install packages and dependencies
      run: |
        docker --version
        python -m pip install --upgrade pip wheel
        pip install -e .[lmm]
        python -c "import autogen"
        pip install coverage pytest
    # `coverage run -a` appends to any existing coverage data file.
    - name: Coverage
      env:
        OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
      run: |
        coverage run -a -m pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py
        coverage xml
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v3
      with:
        file: ./coverage.xml
        flags: unittests
|
||||
|
|
|
@ -15,8 +15,9 @@ on:
|
|||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
permissions: {}
|
||||
# actions: read
|
||||
permissions:
|
||||
{}
|
||||
# actions: read
|
||||
# checks: read
|
||||
# contents: read
|
||||
# deployments: read
|
||||
|
@ -246,7 +247,7 @@ jobs:
|
|||
- name: Coverage
|
||||
run: |
|
||||
# Quoted so the shell does not treat `>=5.3` as a redirect to a file named "=5.3".
pip install "coverage>=5.3"
|
||||
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai
|
||||
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
|
||||
coverage xml
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
|
|
|
@ -0,0 +1,291 @@
|
|||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union
|
||||
|
||||
from openai import OpenAI
|
||||
from PIL.Image import Image
|
||||
|
||||
from autogen import Agent, ConversableAgent, code_utils
|
||||
from autogen.cache import Cache
|
||||
from autogen.agentchat.contrib import img_utils
|
||||
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
|
||||
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
|
||||
|
||||
# Appended to the agent's system message when the capability is added.
SYSTEM_MESSAGE = "You've been given the special ability to generate images."

# Appended to the agent's description when the capability is added.
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."

# Default instructions given to the text analyzer when extracting the image prompt.
PROMPT_INSTRUCTIONS = (
    "In detail, please summarize the provided prompt to generate the image described in the TEXT.\n"
    "DO NOT include any advice. RESPOND like the following example:\n"
    "EXAMPLE: Blue background, 3D shapes, ...\n"
)
|
||||
|
||||
|
||||
class ImageGenerator(Protocol):
    """Structural interface that every image generator must satisfy.

    An implementation supplies two methods: `generate_image`, which turns a text prompt into a PIL image, and
    `cache_key`, which maps a prompt to a string suitable for cache lookups.

    NOTE: Current implementation does not allow you to edit a previously existing image.
    """

    def generate_image(self, prompt: str) -> Image:
        """Create an image from a text prompt.

        Args:
            prompt: Text describing the image to produce.

        Returns:
            The generated image as a PIL Image object.

        Raises:
            ValueError: If the image generation fails.
        """
        ...

    def cache_key(self, prompt: str) -> str:
        """Build the cache key associated with a prompt.

        The returned key is what callers use to store and later look up the image generated for the prompt.

        Args:
            prompt: Text describing the image to produce.

        Returns:
            A unique string that can be used as a cache key.
        """
        ...
|
||||
|
||||
|
||||
class DalleImageGenerator:
    """Generates images using OpenAI's DALL-E models.

    This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E
    models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate.

    Note: Current implementation does not allow you to edit a previously existing image.
    """

    def __init__(
        self,
        llm_config: Dict,
        resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
        quality: Literal["standard", "hd"] = "standard",
        num_images: int = 1,
    ):
        """
        Args:
            llm_config (dict): llm config, must contain a valid dalle model and OpenAI API key in config_list.
                Only the first entry of config_list is used.
            resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
            quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
            num_images (int): The number of images to request. Note that `generate_image` only returns the
                first image of the response.

        Raises:
            ValueError: If the model is not a supported DALL-E model or the resolution is malformed.
        """
        config_list = llm_config["config_list"]
        _validate_dalle_model(config_list[0]["model"])
        _validate_resolution_format(resolution)

        self._model = config_list[0]["model"]
        self._resolution = resolution
        self._quality = quality
        self._num_images = num_images
        self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])

    def generate_image(self, prompt: str) -> Image:
        """Requests an image for `prompt` and returns the first result as a PIL image.

        Args:
            prompt: A string describing the desired image.

        Returns:
            The first image of the response, loaded from its URL.

        Raises:
            ValueError: If the response contains no image or the first image has no URL.
        """
        response = self._dalle_client.images.generate(
            model=self._model,
            prompt=prompt,
            size=self._resolution,
            quality=self._quality,
            n=self._num_images,
        )

        # An empty `data` list is reported as the same ValueError as a missing URL,
        # rather than leaking an IndexError from `data[0]`.
        if not response.data or response.data[0].url is None:
            raise ValueError("Failed to generate image.")

        return img_utils.get_pil_image(response.data[0].url)

    def cache_key(self, prompt: str) -> str:
        """Returns a key combining the prompt with every generation setting.

        The key is the comma-joined string of prompt, model, resolution, quality and number of images, so the
        same prompt generated under different settings maps to different keys.
        """
        keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
        return ",".join([str(k) for k in keys])
|
||||
|
||||
|
||||
class ImageGeneration(AgentCapability):
    """This capability allows a ConversableAgent to generate images based on the message received from other Agents.

    1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
        extract relevant details.
    2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
    3. Optionally caches generated images for faster retrieval in future conversations.

    NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
        message received by the agent.

    Example:
        ```python
        import autogen
        from autogen.agentchat.contrib.capabilities import generate_images

        # Assuming you have llm configs configured for the LLMs you want to use and Dalle.
        # Create the agent
        agent = autogen.ConversableAgent(
            name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
        )

        # Create an ImageGenerator with desired settings
        dalle_gen = generate_images.DalleImageGenerator(llm_config={...})

        # Add the ImageGeneration capability to the agent
        generate_images.ImageGeneration(image_generator=dalle_gen).add_to_agent(agent)
        ```
    """

    def __init__(
        self,
        image_generator: ImageGenerator,
        cache: Optional[Cache] = None,
        text_analyzer_llm_config: Optional[Dict] = None,
        text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
        verbosity: int = 0,
        register_reply_position: int = 2,
    ):
        """
        Args:
            image_generator (ImageGenerator): The image generator you would like to use to generate images.
            cache (None or Cache): The cache client to use to store and retrieve generated images. If None,
                no caching will be used.
            text_analyzer_llm_config (Dict or None): The LLM config for the text analyzer. If None, the LLM config will
                be retrieved from the agent you're adding the ability to.
            text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
                incoming messages and extract the prompt for image generation. The default instructions focus on
                summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
                extraction.
                Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
            verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
                analyzer llm calls will be silent if verbosity is less than 2.
                NOTE(review): the value is stored but not read anywhere else in this class - confirm intended use.
            register_reply_position (int): The position of the reply function in the agent's list of reply functions.
                This capability registers a new reply function to handle messages with image generation requests.
                Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
        """
        self._image_generator = image_generator
        self._cache = cache
        self._text_analyzer_llm_config = text_analyzer_llm_config
        self._text_analyzer_instructions = text_analyzer_instructions
        self._verbosity = verbosity
        self._register_reply_position = register_reply_position

        # Both are populated by `add_to_agent`.
        self._agent: Optional[ConversableAgent] = None
        self._text_analyzer: Optional[TextAnalyzerAgent] = None

    def add_to_agent(self, agent: ConversableAgent):
        """Adds the Image Generation capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a reply function: A new reply function is registered with the agent to handle messages that
            potentially request image generation. This function analyzes the message and triggers image generation if
            necessary.
        2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
        3. Updates System Message: The agent's system message is updated to include a message indicating the
            capability to generate images has been added.
        4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
            capability. This might be helpful in certain use cases, like group chats.

        Args:
            agent (ConversableAgent): The ConversableAgent to add the capability to.
        """
        self._agent = agent

        agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)

        # Fall back to the agent's own LLM config when none was given for the text analyzer.
        self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
        self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)

        agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
        agent.description += "\n" + DESCRIPTION_MESSAGE

    def _image_gen_reply(
        self,
        recipient: ConversableAgent,
        messages: Optional[List[Dict]],
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """Reply function that answers with a generated image when the last message asks for one.

        Returns:
            `(True, message)` with a text + image content message when an image was produced, or
            `(False, None)` when there is nothing to analyze or no image was requested.
        """
        # `not messages` covers both None and an empty list; indexing [-1] on [] would raise IndexError.
        if not messages:
            return False, None

        last_message = code_utils.content_str(messages[-1]["content"])

        if not last_message:
            return False, None

        if self._should_generate_image(last_message):
            prompt = self._extract_prompt(last_message)

            # Only call the generator when the cache has no image for this prompt.
            image = self._cache_get(prompt)
            if image is None:
                image = self._image_generator.generate_image(prompt)
                self._cache_set(prompt, image)

            return True, self._generate_content_message(prompt, image)

        else:
            return False, None

    def _should_generate_image(self, message: str) -> bool:
        """Asks the text analyzer whether `message` requests an image; True if its answer contains "yes"."""
        assert self._text_analyzer is not None

        instructions = """
        Does any part of the TEXT ask the agent to generate an image?
        The TEXT must explicitly mention that the image must be generated.
        Answer with just one word, yes or no.
        """
        analysis = self._text_analyzer.analyze_text(message, instructions)

        return "yes" in self._extract_analysis(analysis).lower()

    def _extract_prompt(self, last_message) -> str:
        """Runs the text analyzer with the configured instructions and returns its answer as the image prompt."""
        assert self._text_analyzer is not None

        analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
        return self._extract_analysis(analysis)

    def _cache_get(self, prompt: str) -> Optional[Image]:
        """Returns the cached image for `prompt`, or None when caching is disabled or nothing is stored."""
        if self._cache is not None:
            key = self._image_generator.cache_key(prompt)
            cached_value = self._cache.get(key)

            if cached_value:
                return img_utils.get_pil_image(cached_value)

        return None

    def _cache_set(self, prompt: str, image: Image):
        """Stores `image` under the generator's cache key for `prompt`, encoded as a data URI."""
        if self._cache is not None:
            key = self._image_generator.cache_key(prompt)
            self._cache.set(key, img_utils.pil_to_data_uri(image))

    def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str:
        """Normalizes an analyzer reply (message dict, plain string or None) to a plain string."""
        if isinstance(analysis, dict):
            return code_utils.content_str(analysis["content"])
        else:
            return code_utils.content_str(analysis)

    def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]:
        """Builds the reply message: a text part naming the prompt and an image part holding a data URI."""
        return {
            "content": [
                {"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
                {"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
            ]
        }
|
||||
|
||||
|
||||
### Helpers
|
||||
def _validate_resolution_format(resolution: str):
|
||||
"""Checks if a string is in a valid resolution format (e.g., "1024x768")."""
|
||||
pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits
|
||||
matched_resolution = re.match(pattern, resolution)
|
||||
if matched_resolution is None:
|
||||
raise ValueError(f"Invalid resolution format: {resolution}")
|
||||
|
||||
|
||||
def _validate_dalle_model(model: str):
|
||||
if model not in ["dall-e-3", "dall-e-2"]:
|
||||
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,224 @@
|
|||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Tuple
|
||||
import pytest
|
||||
import sys
|
||||
from autogen import code_utils
|
||||
from autogen.agentchat.conversable_agent import ConversableAgent
|
||||
from autogen.agentchat.user_proxy_agent import UserProxyAgent
|
||||
from autogen.cache.cache import Cache
|
||||
from autogen.oai import openai_utils
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
from autogen.agentchat.contrib.capabilities import generate_images
|
||||
from autogen.agentchat.contrib.img_utils import get_pil_image
|
||||
except ImportError:
|
||||
skip_requirement = True
|
||||
else:
|
||||
skip_requirement = False
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402
|
||||
|
||||
# Model filter for chat-model configs.
filter_dict = {
    "model": [
        "gpt-35-turbo-16k",
        "gpt-3.5-turbo-16k",
    ],
}

# Generator settings combined by the cache-key test.
RESOLUTIONS = [
    "256x256",
    "512x512",
    "1024x1024",
]
QUALITIES = [
    "standard",
    "hd",
]

# Sample image requests used as messages and as cache-key prompts.
PROMPTS = [
    "Generate an image of a robot holding a 'I Love Autogen' sign",
    "Generate an image of a dog holding a 'I Love Autogen' sign",
]
|
||||
|
||||
|
||||
class _TestImageGenerator:
|
||||
def __init__(self, image):
|
||||
self._image = image
|
||||
|
||||
def generate_image(self, prompt: str):
|
||||
return self._image
|
||||
|
||||
def cache_key(self, prompt: str):
|
||||
return prompt
|
||||
|
||||
|
||||
def create_test_agent(name: str = "test_agent", default_auto_reply: str = "") -> ConversableAgent:
    """Build a ConversableAgent with the LLM disabled and the given fallback reply."""
    agent_kwargs = {
        "name": name,
        "llm_config": False,
        "default_auto_reply": default_auto_reply,
    }
    return ConversableAgent(**agent_kwargs)
|
||||
|
||||
|
||||
def dalle_image_generator(dalle_config: Dict[str, Any], resolution: str, quality: str):
    """Create a DalleImageGenerator for the given settings that requests a single image."""
    generator = generate_images.DalleImageGenerator(
        dalle_config,
        resolution=resolution,
        quality=quality,
        num_images=1,
    )
    return generator
|
||||
|
||||
|
||||
def api_key():
    """Return the mock key when OpenAI tests are skipped, else the OPENAI_API_KEY environment value (may be None)."""
    if skip_openai:
        return MOCK_OPEN_AI_API_KEY
    return os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
|
||||
@pytest.fixture
def dalle_config() -> Dict[str, Any]:
    """LLM config for dall-e-2; falls back to a single entry built from `api_key()` when none is discovered."""
    discovered = openai_utils.config_list_from_models(model_list=["dall-e-2"], exclude="aoai")
    config_list = discovered or [{"model": "dall-e-2", "api_key": api_key()}]
    return {"config_list": config_list, "timeout": 120, "cache_seed": None}
|
||||
|
||||
|
||||
@pytest.fixture
def gpt3_config() -> Dict[str, Any]:
    """LLM config listing both spellings of the gpt-3.5-turbo-16k model name, each with the key from `api_key()`."""
    model_names = ("gpt-35-turbo-16k", "gpt-3.5-turbo-16k")
    config_list = [{"model": model_name, "api_key": api_key()} for model_name in model_names]
    return {"config_list": config_list, "timeout": 120, "cache_seed": None}
|
||||
|
||||
|
||||
@pytest.fixture
def image_gen_capability():
    """ImageGeneration capability backed by a stub generator that returns a blank 256x256 RGB image."""
    blank_image = Image.new("RGB", (256, 256))
    return generate_images.ImageGeneration(_TestImageGenerator(blank_image))
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_openai, reason="Requested to skip.")
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_dalle_image_generator(dalle_config: Dict[str, Any]):
    """Tests DalleImageGenerator capability to generate images by calling the OpenAI API."""
    resolution, quality, prompt = RESOLUTIONS[0], QUALITIES[0], PROMPTS[0]

    generated = dalle_image_generator(dalle_config, resolution, quality).generate_image(prompt)

    assert isinstance(generated, Image.Image)
|
||||
|
||||
|
||||
# Every (resolution, quality, prompt) combination is paired with every other one.
@pytest.mark.parametrize("gen_config_1", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS))
@pytest.mark.parametrize("gen_config_2", itertools.product(RESOLUTIONS, QUALITIES, PROMPTS))
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_dalle_image_generator_cache_key(
    dalle_config: Dict[str, Any], gen_config_1: Tuple[str, str, str], gen_config_2: Tuple[str, str, str]
):
    """Tests if DalleImageGenerator creates unique cache keys.

    Two configurations must yield the same key exactly when they are identical.

    Args:
        dalle_config: The LLM config for the DalleImageGenerator.
        gen_config_1: A tuple containing the resolution, quality, and prompt for the first image generator.
        gen_config_2: A tuple containing the resolution, quality, and prompt for the second image generator.
    """
    resolution_1, quality_1, prompt_1 = gen_config_1
    resolution_2, quality_2, prompt_2 = gen_config_2

    first_generator = dalle_image_generator(dalle_config, resolution=resolution_1, quality=quality_1)
    second_generator = dalle_image_generator(dalle_config, resolution=resolution_2, quality=quality_2)

    first_key = first_generator.cache_key(prompt_1)
    second_key = second_generator.cache_key(prompt_2)

    if gen_config_1 != gen_config_2:
        assert first_key != second_key
    else:
        assert first_key == second_key
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_image_generation_capability_positive(monkeypatch, image_gen_capability):
    """Tests ImageGeneration capability to generate images by calling the ImageGenerator.

    Covers the case where the message asks the agent to generate an image: the reply must contain the image
    and must not be the agent's default auto reply.
    """
    auto_reply = "Didn't need to generate an image."

    # Stub out the two TextAnalyzerAgent-backed methods so no API calls are made.
    # Improves reproducibility and reduces flakiness of the test.
    capability_cls = generate_images.ImageGeneration
    monkeypatch.setattr(capability_cls, "_should_generate_image", lambda _, __: True)
    monkeypatch.setattr(capability_cls, "_extract_prompt", lambda _, __: PROMPTS[0])

    agent = create_test_agent(default_auto_reply=auto_reply)
    image_gen_capability.add_to_agent(agent)
    user = UserProxyAgent("user", human_input_mode="NEVER")

    user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)

    reply = agent.last_message()
    assert reply

    reply_text = code_utils.content_str(reply["content"])
    assert "<image>" in reply_text
    assert auto_reply not in reply_text
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_image_generation_capability_negative(monkeypatch, image_gen_capability):
    """Tests ImageGeneration capability to generate images by calling the ImageGenerator.

    Covers the case where the message does not ask for an image: the reply must be exactly the agent's
    default auto reply and contain no image.
    """
    auto_reply = "Didn't need to generate an image."

    # Stub out the two TextAnalyzerAgent-backed methods so no API calls are made.
    # Improves reproducibility and reduces flakiness of the test.
    capability_cls = generate_images.ImageGeneration
    monkeypatch.setattr(capability_cls, "_should_generate_image", lambda _, __: False)
    monkeypatch.setattr(capability_cls, "_extract_prompt", lambda _, __: PROMPTS[0])

    agent = create_test_agent(name="test_agent", default_auto_reply=auto_reply)
    image_gen_capability.add_to_agent(agent)
    user = UserProxyAgent("user", human_input_mode="NEVER")

    user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)

    reply = agent.last_message()
    assert reply

    reply_text = code_utils.content_str(reply["content"])
    assert "<image>" not in reply_text
    assert auto_reply == reply_text
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip_requirement, reason="Dependencies are not installed.")
def test_image_generation_capability_cache(monkeypatch):
    """Tests ImageGeneration capability to cache the generated images.

    A first agent generates and caches a 256x256 image. A second agent shares the cache but has a generator
    producing 512x512 images; its reply must still carry the 256x256 image, proving it came from the cache.
    """
    test_image_size = (256, 256)

    # Patching the _should_generate_image and _extract_prompt methods to avoid TextAnalyzerAgent making API calls.
    monkeypatch.setattr(generate_images.ImageGeneration, "_should_generate_image", lambda _, __: True)
    monkeypatch.setattr(generate_images.ImageGeneration, "_extract_prompt", lambda _, __: PROMPTS[0])

    with tempfile.TemporaryDirectory() as temp_dir:
        cache = Cache.disk(cache_path_root=temp_dir)
        try:
            user = UserProxyAgent("user", human_input_mode="NEVER")
            agent = create_test_agent()

            test_image_generator = _TestImageGenerator(Image.new("RGB", test_image_size))
            image_gen_capability = generate_images.ImageGeneration(test_image_generator, cache=cache)
            image_gen_capability.add_to_agent(agent)

            user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)

            # Checking if the image has been cached by creating a new agent with a different image generator.
            agent = create_test_agent(name="test_agent_2")
            test_image_generator = _TestImageGenerator(Image.new("RGB", (512, 512)))
            image_gen_capability = generate_images.ImageGeneration(test_image_generator, cache=cache)
            image_gen_capability.add_to_agent(agent)

            user.send(message=PROMPTS[0], recipient=agent, request_reply=True, silent=True)

            last_message = agent.last_message()
        finally:
            # Release the cache's files before the temporary directory is deleted.
            cache.close()

    assert last_message

    image_dict = [image for image in last_message["content"] if image["type"] == "image_url"]
    image = get_pil_image(image_dict[0]["image_url"]["url"])

    assert image.size == test_image_size
|
||||
|
||||
|
||||
# Allows running the live DALL-E test directly, without the pytest fixtures.
if __name__ == "__main__":
    _manual_config_list = openai_utils.config_list_from_models(model_list=["dall-e-2"], exclude="aoai")
    test_dalle_image_generator(dalle_config={"config_list": _manual_config_list})
|
Loading…
Reference in New Issue