Message "content" now supports both `str` and `List` in Agents (#713)

* Change "content" type in Conversable Agent

* content and system_message support str and List
Update for all other agents

* Content_str now also takes None as input

* Group Chat now works with LMM too

* Style: newline for import in Conversable Agentt

* Add test for gourpchat + lmm

* Resolve comments
1. Undo AssistantAgent changes
2. Modify the asserts and raises in `content_str` function and update
test accordingly.

* Undo AssistantAgent

* Update comments and add assertion for LMM

* Typo fix in docstring for content_str

* Remove “None” out conversable_agent.py

* Lint message to dict in multimodal_conversable_agent.py

* Address lint issues

* linting

* Move lmm test into contrib test

* Resolve 2 comments

* Move img_utils into contrib folder

* Resolve img_utils path issues
This commit is contained in:
Beibin Li 2023-12-02 16:40:50 -09:00 committed by GitHub
parent 77e1d28c1b
commit c19f234149
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 362 additions and 198 deletions

View File

@ -1,60 +0,0 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: ContribTests
on:
pull_request:
branches: ['main', 'dev/v0.2']
paths:
- 'autogen/img_utils.py'
- 'autogen/agentchat/contrib/multimodal_conversable_agent.py'
- 'autogen/agentchat/contrib/llava_agent.py'
- 'test/test_img_utils.py'
- 'test/agentchat/contrib/test_lmm.py'
- 'test/agentchat/contrib/test_llava.py'
- '.github/workflows/lmm-test.yml'
- 'setup.py'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
LMMTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for LMM
run: |
pip install -e .[lmm]
pip uninstall -y openai
- name: Test LMM and LLaVA
run: |
pytest test/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests

View File

@ -136,3 +136,40 @@ jobs:
- name: Test TeachableAgent - name: Test TeachableAgent
run: | run: |
pytest test/agentchat/contrib/test_teachable_agent.py pytest test/agentchat/contrib/test_teachable_agent.py
LMMTest:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-2019]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for LMM
run: |
pip install -e .[lmm]
pip uninstall -y openai
- name: Test LMM and LLaVA
run: |
pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests

1
.gitignore vendored
View File

@ -167,6 +167,7 @@ wolfram.txt
# DB on disk for TeachableAgent # DB on disk for TeachableAgent
tmp/ tmp/
test/my_tmp/*
# Storage for the AgentEval output # Storage for the AgentEval output
test/test_files/agenteval-in-out/out/ test/test_files/agenteval-in-out/out/

View File

@ -1,6 +1,7 @@
from .conversable_agent import ConversableAgent
from typing import Callable, Dict, Literal, Optional, Union from typing import Callable, Dict, Literal, Optional, Union
from .conversable_agent import ConversableAgent
class AssistantAgent(ConversableAgent): class AssistantAgent(ConversableAgent):
"""(In preview) Assistant agent, designed to solve a task with LLM. """(In preview) Assistant agent, designed to solve a task with LLM.

View File

@ -10,9 +10,9 @@ import requests
from regex import R from regex import R
from autogen.agentchat.agent import Agent from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.img_utils import get_image_data, llava_formater
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.code_utils import content_str from autogen.code_utils import content_str
from autogen.img_utils import get_image_data, llava_formater
try: try:
from termcolor import colored from termcolor import colored

View File

@ -1,8 +1,9 @@
import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from autogen import OpenAIWrapper from autogen import OpenAIWrapper
from autogen.agentchat import Agent, ConversableAgent from autogen.agentchat import Agent, ConversableAgent
from autogen.img_utils import gpt4v_formatter from autogen.agentchat.contrib.img_utils import gpt4v_formatter
try: try:
from termcolor import colored from termcolor import colored
@ -41,19 +42,14 @@ class MultimodalConversableAgent(ConversableAgent):
*args, *args,
**kwargs, **kwargs,
) )
# call the setter to handle special format.
self.update_system_message(system_message) self.update_system_message(system_message)
self._is_termination_msg = ( self._is_termination_msg = (
is_termination_msg is_termination_msg
if is_termination_msg is not None if is_termination_msg is not None
else (lambda x: any([item["text"] == "TERMINATE" for item in x.get("content") if item["type"] == "text"])) else (lambda x: content_str(x.get("content")) == "TERMINATE")
) )
@property
def system_message(self) -> List:
"""Return the system message."""
return self._oai_system_message[0]["content"]
def update_system_message(self, system_message: Union[Dict, List, str]): def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message. """Update the system message.
@ -64,44 +60,29 @@ class MultimodalConversableAgent(ConversableAgent):
self._oai_system_message[0]["role"] = "system" self._oai_system_message[0]["role"] = "system"
@staticmethod @staticmethod
def _message_to_dict(message: Union[Dict, List, str]): def _message_to_dict(message: Union[Dict, List, str]) -> Dict:
"""Convert a message to a dictionary. """Convert a message to a dictionary. This implementation
handles the GPT-4V formatting for easier prompts.
The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary. The message can be a string, a dictionary, or a list of dictionaries:
- If it's a string, it will be cast into a list and placed in the 'content' field.
- If it's a list, it will be directly placed in the 'content' field.
- If it's a dictionary, it is already in message dict format. The 'content' field of this dictionary
will be processed using the gpt4v_formatter.
""" """
if isinstance(message, str): if isinstance(message, str):
return {"content": gpt4v_formatter(message)} return {"content": gpt4v_formatter(message)}
if isinstance(message, list): if isinstance(message, list):
return {"content": message} return {"content": message}
else: if isinstance(message, dict):
assert "content" in message, "The message dict must have a `content` field"
if isinstance(message["content"], str):
message = copy.deepcopy(message)
message["content"] = gpt4v_formatter(message["content"])
try:
content_str(message["content"])
except (TypeError, ValueError) as e:
print("The `content` field should be compatible with the content_str function!")
raise e
return message return message
raise ValueError(f"Unsupported message type: {type(message)}")
def _print_received_message(self, message: Union[Dict, str], sender: Agent):
# print the message received
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
if message.get("role") == "function":
func_print = f"***** Response from calling function \"{message['name']}\" *****"
print(colored(func_print, "green"), flush=True)
print(content_str(message["content"]), flush=True)
print(colored("*" * len(func_print), "green"), flush=True)
else:
content = message.get("content")
if content is not None:
if "context" in message:
content = OpenAIWrapper.instantiate(
content,
message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
print(content_str(content), flush=True)
if "function_call" in message:
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
print(colored(func_print, "green"), flush=True)
print(
"Arguments: \n",
message["function_call"].get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print(colored("*" * len(func_print), "green"), flush=True)
print("\n", "-" * 80, flush=True, sep="")

View File

@ -1,18 +1,14 @@
import asyncio import asyncio
from collections import defaultdict
import copy import copy
import json import json
import logging import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
from autogen import OpenAIWrapper from autogen import OpenAIWrapper
from autogen.code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
from .agent import Agent from .agent import Agent
from autogen.code_utils import (
DEFAULT_MODEL,
UNKNOWN,
execute_code,
extract_code,
infer_lang,
)
try: try:
from termcolor import colored from termcolor import colored
@ -50,7 +46,7 @@ class ConversableAgent(Agent):
def __init__( def __init__(
self, self,
name: str, name: str,
system_message: Optional[str] = "You are a helpful AI Assistant.", system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.",
is_termination_msg: Optional[Callable[[Dict], bool]] = None, is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None, max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "TERMINATE", human_input_mode: Optional[str] = "TERMINATE",
@ -62,7 +58,7 @@ class ConversableAgent(Agent):
""" """
Args: Args:
name (str): name of the agent. name (str): name of the agent.
system_message (str): system message for the ChatCompletion inference. system_message (str or list): system message for the ChatCompletion inference.
is_termination_msg (function): a function that takes a message in the form of a dictionary is_termination_msg (function): a function that takes a message in the form of a dictionary
and returns a boolean value indicating if this received message is a termination message. and returns a boolean value indicating if this received message is a termination message.
The dict can contain the following keys: "content", "role", "name", "function_call". The dict can contain the following keys: "content", "role", "name", "function_call".
@ -105,8 +101,11 @@ class ConversableAgent(Agent):
self._oai_messages = defaultdict(list) self._oai_messages = defaultdict(list)
self._oai_system_message = [{"content": system_message, "role": "system"}] self._oai_system_message = [{"content": system_message, "role": "system"}]
self._is_termination_msg = ( self._is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content") == "TERMINATE") is_termination_msg
if is_termination_msg is not None
else (lambda x: content_str(x.get("content")) == "TERMINATE")
) )
if llm_config is False: if llm_config is False:
self.llm_config = False self.llm_config = False
self.client = None self.client = None
@ -190,15 +189,15 @@ class ConversableAgent(Agent):
) )
@property @property
def system_message(self): def system_message(self) -> Union[str, List]:
"""Return the system message.""" """Return the system message."""
return self._oai_system_message[0]["content"] return self._oai_system_message[0]["content"]
def update_system_message(self, system_message: str): def update_system_message(self, system_message: Union[str, List]):
"""Update the system message. """Update the system message.
Args: Args:
system_message (str): system message for the ChatCompletion inference. system_message (str or List): system message for the ChatCompletion inference.
""" """
self._oai_system_message[0]["content"] = system_message self._oai_system_message[0]["content"] = system_message
@ -258,7 +257,7 @@ class ConversableAgent(Agent):
return None if self._code_execution_config is False else self._code_execution_config.get("use_docker") return None if self._code_execution_config is False else self._code_execution_config.get("use_docker")
@staticmethod @staticmethod
def _message_to_dict(message: Union[Dict, str]): def _message_to_dict(message: Union[Dict, str]) -> Dict:
"""Convert a message to a dictionary. """Convert a message to a dictionary.
The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary. The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary.
@ -314,7 +313,7 @@ class ConversableAgent(Agent):
Args: Args:
message (dict or str): message to be sent. message (dict or str): message to be sent.
The message could contain the following fields: The message could contain the following fields:
- content (str): Required, the content of the message. (Can be None) - content (str or List): Required, the content of the message. (Can be None)
- function_call (str): the name of the function to be called. - function_call (str): the name of the function to be called.
- name (str): the name of the function to be called. - name (str): the name of the function to be called.
- role (str): the role of the message, any role that is not "function" - role (str): the role of the message, any role that is not "function"
@ -363,7 +362,7 @@ class ConversableAgent(Agent):
Args: Args:
message (dict or str): message to be sent. message (dict or str): message to be sent.
The message could contain the following fields: The message could contain the following fields:
- content (str): Required, the content of the message. (Can be None) - content (str or List): Required, the content of the message. (Can be None)
- function_call (str): the name of the function to be called. - function_call (str): the name of the function to be called.
- name (str): the name of the function to be called. - name (str): the name of the function to be called.
- role (str): the role of the message, any role that is not "function" - role (str): the role of the message, any role that is not "function"
@ -419,7 +418,7 @@ class ConversableAgent(Agent):
message["context"], message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False), self.llm_config and self.llm_config.get("allow_format_str_template", False),
) )
print(content, flush=True) print(content_str(content), flush=True)
if "function_call" in message: if "function_call" in message:
function_call = dict(message["function_call"]) function_call = dict(message["function_call"])
func_print = ( func_print = (
@ -435,7 +434,7 @@ class ConversableAgent(Agent):
print(colored("*" * len(func_print), "green"), flush=True) print(colored("*" * len(func_print), "green"), flush=True)
print("\n", "-" * 80, flush=True, sep="") print("\n", "-" * 80, flush=True, sep="")
def _process_received_message(self, message, sender, silent): def _process_received_message(self, message: Union[Dict, str], sender: Agent, silent: bool):
message = self._message_to_dict(message) message = self._message_to_dict(message)
# When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.) # When the agent receives a message, the role of the message is "user". (If 'role' exists and is 'function', it will remain unchanged.)
valid = self._append_oai_message(message, "user", sender) valid = self._append_oai_message(message, "user", sender)
@ -681,7 +680,7 @@ class ConversableAgent(Agent):
messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None, sender: Optional[Agent] = None,
config: Optional[Any] = None, config: Optional[Any] = None,
): ) -> Tuple[bool, Union[Dict, None]]:
"""Generate a reply using function call.""" """Generate a reply using function call."""
if config is None: if config is None:
config = self config = self
@ -698,7 +697,7 @@ class ConversableAgent(Agent):
messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None, sender: Optional[Agent] = None,
config: Optional[Any] = None, config: Optional[Any] = None,
): ) -> Tuple[bool, Union[Dict, None]]:
"""Generate a reply using async function call.""" """Generate a reply using async function call."""
if config is None: if config is None:
config = self config = self
@ -720,8 +719,26 @@ class ConversableAgent(Agent):
messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None, sender: Optional[Agent] = None,
config: Optional[Any] = None, config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]: ) -> Tuple[bool, Union[str, None]]:
"""Check if the conversation should be terminated, and if human reply is provided.""" """Check if the conversation should be terminated, and if human reply is provided.
This method checks for conditions that require the conversation to be terminated, such as reaching
a maximum number of consecutive auto-replies or encountering a termination message. Additionally,
it prompts for and processes human input based on the configured human input mode, which can be
'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter
for the conversation and prints relevant messages based on the human input received.
Args:
- messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history.
- sender (Optional[Agent]): The agent object representing the sender of the message.
- config (Optional[Any]): Configuration object, defaults to the current instance if not provided.
Returns:
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
# Function implementation...
if config is None: if config is None:
config = self config = self
if messages is None: if messages is None:
@ -791,8 +808,24 @@ class ConversableAgent(Agent):
messages: Optional[List[Dict]] = None, messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None, sender: Optional[Agent] = None,
config: Optional[Any] = None, config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]: ) -> Tuple[bool, Union[str, None]]:
"""(async) Check if the conversation should be terminated, and if human reply is provided.""" """(async) Check if the conversation should be terminated, and if human reply is provided.
This method checks for conditions that require the conversation to be terminated, such as reaching
a maximum number of consecutive auto-replies or encountering a termination message. Additionally,
it prompts for and processes human input based on the configured human input mode, which can be
'ALWAYS', 'NEVER', or 'TERMINATE'. The method also manages the consecutive auto-reply counter
for the conversation and prints relevant messages based on the human input received.
Args:
- messages (Optional[List[Dict]]): A list of message dictionaries, representing the conversation history.
- sender (Optional[Agent]): The agent object representing the sender of the message.
- config (Optional[Any]): Configuration object, defaults to the current instance if not provided.
Returns:
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
if config is None: if config is None:
config = self config = self
if messages is None: if messages is None:
@ -962,8 +995,20 @@ class ConversableAgent(Agent):
return reply return reply
return self._default_auto_reply return self._default_auto_reply
def _match_trigger(self, trigger, sender): def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Agent) -> bool:
"""Check if the sender matches the trigger.""" """Check if the sender matches the trigger.
Args:
- trigger (Union[None, str, type, Agent, Callable, List]): The condition to match against the sender.
Can be `None`, string, type, `Agent` instance, callable, or a list of these.
- sender (Agent): The sender object or type to be matched against the trigger.
Returns:
- bool: Returns `True` if the sender matches the trigger, otherwise `False`.
Raises:
- ValueError: If the trigger type is unsupported.
"""
if trigger is None: if trigger is None:
return sender is None return sender is None
elif isinstance(trigger, str): elif isinstance(trigger, str):
@ -971,9 +1016,12 @@ class ConversableAgent(Agent):
elif isinstance(trigger, type): elif isinstance(trigger, type):
return isinstance(sender, trigger) return isinstance(sender, trigger)
elif isinstance(trigger, Agent): elif isinstance(trigger, Agent):
# return True if the sender is the same type (class) as the trigger
return trigger == sender return trigger == sender
elif isinstance(trigger, Callable): elif isinstance(trigger, Callable):
return trigger(sender) rst = trigger(sender)
assert rst in [True, False], f"trigger {trigger} must return a boolean value."
return rst
elif isinstance(trigger, list): elif isinstance(trigger, list):
return any(self._match_trigger(t, sender) for t in trigger) return any(self._match_trigger(t, sender) for t in trigger)
else: else:
@ -1095,7 +1143,7 @@ class ConversableAgent(Agent):
result.append(char) result.append(char)
return "".join(result) return "".join(result)
def execute_function(self, func_call): def execute_function(self, func_call) -> Tuple[bool, Dict[str, str]]:
"""Execute a function call and return the result. """Execute a function call and return the result.
Override this function to modify the way to execute a function call. Override this function to modify the way to execute a function call.
@ -1195,7 +1243,10 @@ class ConversableAgent(Agent):
"""Generate the initial message for the agent. """Generate the initial message for the agent.
Override this function to customize the initial message based on user's request. Override this function to customize the initial message based on user's request.
If not overridden, "message" needs to be provided in the context. If not overriden, "message" needs to be provided in the context.
Args:
**context: any context information, and "message" parameter needs to be provided.
""" """
return context["message"] return context["message"]

View File

@ -1,9 +1,11 @@
import logging import logging
import sys
import random import random
import re
import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import re
from ..code_utils import content_str
from .agent import Agent from .agent import Agent
from .conversable_agent import ConversableAgent from .conversable_agent import ConversableAgent
@ -50,6 +52,14 @@ class GroupChat:
"""Reset the group chat.""" """Reset the group chat."""
self.messages.clear() self.messages.clear()
def append(self, message: Dict):
"""Append a message to the group chat.
We cast the content to str here so that it can be managed by text-based
model.
"""
message["content"] = content_str(message["content"])
self.messages.append(message)
def agent_by_name(self, name: str) -> Agent: def agent_by_name(self, name: str) -> Agent:
"""Returns the agent with a given name.""" """Returns the agent with a given name."""
return self.agents[self.agent_names.index(name)] return self.agents[self.agent_names.index(name)]
@ -64,7 +74,7 @@ class GroupChat:
if self.agents[(offset + i) % len(self.agents)] in agents: if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)] return self.agents[(offset + i) % len(self.agents)]
def select_speaker_msg(self, agents: List[Agent]): def select_speaker_msg(self, agents: List[Agent]) -> str:
"""Return the message for selecting the next speaker.""" """Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available: return f"""You are in a role play game. The following roles are available:
{self._participant_roles(agents)}. {self._participant_roles(agents)}.
@ -72,7 +82,7 @@ class GroupChat:
Read the following conversation. Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.""" Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
def manual_select_speaker(self, agents: List[Agent]) -> Agent: def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]:
"""Manually select the next speaker.""" """Manually select the next speaker."""
print("Please select the next speaker from the following list:") print("Please select the next speaker from the following list:")
@ -190,19 +200,26 @@ Then select the next role from {[agent.name for agent in agents]} to play. Only
roles = [] roles = []
for agent in agents: for agent in agents:
if agent.system_message.strip() == "": if content_str(agent.system_message).strip() == "":
logger.warning( logger.warning(
f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat." f"The agent '{agent.name}' has an empty system_message, and may not work well with GroupChat."
) )
roles.append(f"{agent.name}: {agent.system_message}") roles.append(f"{agent.name}: {agent.system_message}")
return "\n".join(roles) return "\n".join(roles)
def _mentioned_agents(self, message_content: str, agents: List[Agent]) -> Dict: def _mentioned_agents(self, message_content: Union[str, List], agents: List[Agent]) -> Dict:
""" """Counts the number of times each agent is mentioned in the provided message content.
Finds and counts agent mentions in the string message_content, taking word boundaries into account.
Returns: A dictionary mapping agent names to mention counts (to be included, at least one mention must occur) Args:
message_content (Union[str, List]): The content of the message, either as a single string or a list of strings.
agents (List[Agent]): A list of Agent objects, each having a 'name' attribute to be searched in the message content.
Returns:
Dict: a counter for mentioned agents.
""" """
# Cast message content to str
message_content = content_str(message_content)
mentions = dict() mentions = dict()
for agent in agents: for agent in agents:
regex = ( regex = (
@ -224,7 +241,7 @@ class GroupChatManager(ConversableAgent):
# unlimited consecutive auto reply by default # unlimited consecutive auto reply by default
max_consecutive_auto_reply: Optional[int] = sys.maxsize, max_consecutive_auto_reply: Optional[int] = sys.maxsize,
human_input_mode: Optional[str] = "NEVER", human_input_mode: Optional[str] = "NEVER",
system_message: Optional[str] = "Group chat manager.", system_message: Optional[Union[str, List]] = "Group chat manager.",
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
@ -256,12 +273,12 @@ class GroupChatManager(ConversableAgent):
# set the name to speaker's name if the role is not function # set the name to speaker's name if the role is not function
if message["role"] != "function": if message["role"] != "function":
message["name"] = speaker.name message["name"] = speaker.name
groupchat.messages.append(message)
groupchat.append(message)
if self._is_termination_msg(message): if self._is_termination_msg(message):
# The conversation is over # The conversation is over
break break
# broadcast the message to all agents except the speaker # broadcast the message to all agents except the speaker
for agent in groupchat.agents: for agent in groupchat.agents:
if agent != speaker: if agent != speaker:
@ -306,7 +323,8 @@ class GroupChatManager(ConversableAgent):
# set the name to speaker's name if the role is not function # set the name to speaker's name if the role is not function
if message["role"] != "function": if message["role"] != "function":
message["name"] = speaker.name message["name"] = speaker.name
groupchat.messages.append(message)
groupchat.append(message)
if self._is_termination_msg(message): if self._is_termination_msg(message):
# The conversation is over # The conversation is over

View File

@ -1,5 +1,6 @@
from typing import Callable, Dict, List, Literal, Optional, Union
from .conversable_agent import ConversableAgent from .conversable_agent import ConversableAgent
from typing import Callable, Dict, Literal, Optional, Union
class UserProxyAgent(ConversableAgent): class UserProxyAgent(ConversableAgent):
@ -25,7 +26,7 @@ class UserProxyAgent(ConversableAgent):
code_execution_config: Optional[Union[Dict, Literal[False]]] = None, code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Optional[Union[str, Dict, None]] = "", default_auto_reply: Optional[Union[str, Dict, None]] = "",
llm_config: Optional[Union[Dict, Literal[False]]] = False, llm_config: Optional[Union[Dict, Literal[False]]] = False,
system_message: Optional[str] = "", system_message: Optional[Union[str, List]] = "",
): ):
""" """
Args: Args:
@ -66,7 +67,7 @@ class UserProxyAgent(ConversableAgent):
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options. for available options.
Default to false, which disables llm-based auto reply. Default to false, which disables llm-based auto reply.
system_message (str): system message for ChatCompletion inference. system_message (str or List): system message for ChatCompletion inference.
Only used when llm_config is not False. Use it to reprogram the agent. Only used when llm_config is not False. Use it to reprogram the agent.
""" """
super().__init__( super().__init__(

View File

@ -38,16 +38,44 @@ PATH_SEPARATOR = WIN32 and "\\" or "/"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def content_str(content: Union[str, List]) -> str: def content_str(content: Union[str, List, None]) -> str:
if type(content) is str: """Converts `content` into a string format.
This function processes content that may be a string, a list of mixed text and image URLs, or None,
and converts it into a string. Text is directly appended to the result string, while image URLs are
represented by a placeholder image token. If the content is None, an empty string is returned.
Args:
- content (Union[str, List, None]): The content to be processed. Can be a string, a list of dictionaries
representing text and image URLs, or None.
Returns:
str: A string representation of the input content. Image URLs are replaced with an image token.
Note:
- The function expects each dictionary in the list to have a "type" key that is either "text" or "image_url".
For "text" type, the "text" key's value is appended to the result. For "image_url", an image token is appended.
- This function is useful for handling content that may include both text and image references, especially
in contexts where images need to be represented as placeholders.
"""
if content is None:
return ""
if isinstance(content, str):
return content return content
if not isinstance(content, list):
raise TypeError(f"content must be None, str, or list, but got {type(content)}")
rst = "" rst = ""
for item in content: for item in content:
if not isinstance(item, dict):
raise TypeError("Wrong content format: every element should be dict if the content is a list.")
assert "type" in item, "Wrong content format. Missing 'type' key in content's dict."
if item["type"] == "text": if item["type"] == "text":
rst += item["text"] rst += item["text"]
else: elif item["type"] == "image_url":
assert isinstance(item, dict) and item["type"] == "image_url", "Wrong content format."
rst += "<image>" rst += "<image>"
else:
raise ValueError(f"Wrong content format: unknown type {item['type']} within the content")
return rst return rst

View File

@ -39,7 +39,7 @@
"import autogen\n", "import autogen\n",
"from autogen import AssistantAgent, Agent, UserProxyAgent, ConversableAgent\n", "from autogen import AssistantAgent, Agent, UserProxyAgent, ConversableAgent\n",
"\n", "\n",
"from autogen.img_utils import get_image_data, _to_pil\n", "from autogen.agentchat.contrib.img_utils import get_image_data, _to_pil\n",
"from termcolor import colored\n", "from termcolor import colored\n",
"import random" "import random"
] ]

View File

@ -91,38 +91,6 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"id": "57462351",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['openai']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Remove the `api_type` param as it is not needed for 4V\n",
"[config.pop(\"api_type\", None) for config in config_list_4v]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e23df0dd",
"metadata": {},
"outputs": [],
"source": [
"# image_agent._oai_messages[user_proxy]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "67157629", "id": "67157629",
"metadata": { "metadata": {
"scrolled": false "scrolled": false
@ -180,7 +148,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 4,
"id": "73a2b234", "id": "73a2b234",
"metadata": { "metadata": {
"scrolled": false "scrolled": false
@ -236,7 +204,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 5,
"id": "e8eca993", "id": "e8eca993",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -339,7 +307,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 6,
"id": "977b9017", "id": "977b9017",
"metadata": { "metadata": {
"scrolled": false "scrolled": false
@ -724,7 +692,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 7,
"id": "f0a58827", "id": "f0a58827",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -736,10 +704,100 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "b95bf449", "id": "c6206648",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
},
{
"cell_type": "markdown",
"id": "a95d87c2",
"metadata": {},
"source": [
"## Group Chat Example with Multimodal Agent"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "56bd5742",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mUser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"Describe the image:\n",
" <img https://th.bing.com/th/id/R.422068ce8af4e15b0634fe2540adea7a?rik=y4OcXBE%2fqutDOw&pid=ImgRaw&r=0>.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mimage-explainer-1\u001b[0m (to chat_manager):\n",
"\n",
"In a soft-focus world, a caramel-colored puppy with a coat of curly fur sits serenely, its innocent eyes gazing into the distance. Adorned with a collar that hosts a vibrant, multicolored bandana and a shiny tag engraved with the name \"Webster,\" the pup exudes a sense of youthful curiosity and gentle charm. Behind this bundle of joy, the muted backdrop of a home's interior whispers tales of comfort and domesticity, with a pair of black boots resting by the door, hinting at the comings and goings of human life amidst which this little creature finds its love and belonging.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mimage-explainer-2\u001b[0m (to chat_manager):\n",
"\n",
"The image shows a young, caramel-colored puppy with curly fur sitting on the floor. The puppy is wearing a blue collar with a colorful bandana and a tag that appears to have the name \"Webster\" engraved on it. In the background, there are black boots near a white door, suggesting an indoor, home setting. The focus is on the puppy, making the background appear softly blurred. The puppy's expression is gentle, with a hint of curiosity in its eyes.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mUser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mUser_proxy\u001b[0m (to chat_manager):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"agent1 = MultimodalConversableAgent(\n",
" name=\"image-explainer-1\",\n",
" max_consecutive_auto_reply=10,\n",
" llm_config={\"config_list\": config_list_4v, \"temperature\": 0.5, \"max_tokens\": 300},\n",
" system_message=\"Your image description is poetic and engaging.\",\n",
")\n",
"agent2 = MultimodalConversableAgent(\n",
" name=\"image-explainer-2\",\n",
" max_consecutive_auto_reply=10,\n",
" llm_config={\"config_list\": config_list_4v, \"temperature\": 0.5, \"max_tokens\": 300},\n",
" system_message=\"Your image description is factual and to the point.\",\n",
")\n",
"\n",
"user_proxy = autogen.UserProxyAgent(\n",
" name=\"User_proxy\",\n",
" system_message=\"Ask both image explainer 1 and 2 for their description.\",\n",
" human_input_mode=\"TERMINATE\", # Try between ALWAYS, NEVER, and TERMINATE\n",
" max_consecutive_auto_reply=10,\n",
")\n",
"\n",
"# We set max_round to 5\n",
"groupchat = autogen.GroupChat(agents=[agent1, agent2, user_proxy], \n",
" messages=[], \n",
" max_round=5)\n",
"group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, \n",
" llm_config=gpt4_llm_config)\n",
"\n",
"user_proxy.initiate_chat(group_chat_manager,\n",
" message=f\"\"\"Describe the image:\n",
" <img https://th.bing.com/th/id/R.422068ce8af4e15b0634fe2540adea7a?rik=y4OcXBE%2fqutDOw&pid=ImgRaw&r=0>.\"\"\")"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -10,7 +10,7 @@ import requests
try: try:
from PIL import Image from PIL import Image
from autogen.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formater from autogen.agentchat.contrib.img_utils import extract_img_paths, get_image_data, gpt4v_formatter, llava_formater
except ImportError: except ImportError:
skip = True skip = True
else: else:
@ -71,7 +71,7 @@ class TestLlavaFormater(unittest.TestCase):
result = llava_formater(prompt) result = llava_formater(prompt)
self.assertEqual(result, expected_output) self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data") @patch("autogen.agentchat.contrib.img_utils.get_image_data")
def test_with_images(self, mock_get_image_data): def test_with_images(self, mock_get_image_data):
""" """
Test the llava_formater function with a prompt containing images. Test the llava_formater function with a prompt containing images.
@ -84,7 +84,7 @@ class TestLlavaFormater(unittest.TestCase):
result = llava_formater(prompt) result = llava_formater(prompt)
self.assertEqual(result, expected_output) self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data") @patch("autogen.agentchat.contrib.img_utils.get_image_data")
def test_with_ordered_images(self, mock_get_image_data): def test_with_ordered_images(self, mock_get_image_data):
""" """
Test the llava_formater function with ordered image tokens. Test the llava_formater function with ordered image tokens.
@ -109,7 +109,7 @@ class TestGpt4vFormatter(unittest.TestCase):
result = gpt4v_formatter(prompt) result = gpt4v_formatter(prompt)
self.assertEqual(result, expected_output) self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data") @patch("autogen.agentchat.contrib.img_utils.get_image_data")
def test_with_images(self, mock_get_image_data): def test_with_images(self, mock_get_image_data):
""" """
Test the gpt4v_formatter function with a prompt containing images. Test the gpt4v_formatter function with a prompt containing images.
@ -126,7 +126,7 @@ class TestGpt4vFormatter(unittest.TestCase):
result = gpt4v_formatter(prompt) result = gpt4v_formatter(prompt)
self.assertEqual(result, expected_output) self.assertEqual(result, expected_output)
@patch("autogen.img_utils.get_image_data") @patch("autogen.agentchat.contrib.img_utils.get_image_data")
def test_multiple_images(self, mock_get_image_data): def test_multiple_images(self, mock_get_image_data):
""" """
Test the gpt4v_formatter function with a prompt containing multiple images. Test the gpt4v_formatter function with a prompt containing multiple images.

View File

@ -79,5 +79,53 @@ class TestMultimodalConversableAgent(unittest.TestCase):
self.agent._print_received_message.assert_called_with(message_str, sender) self.agent._print_received_message.assert_called_with(message_str, sender)
@pytest.mark.skipif(skip, reason="Dependency not installed")
def test_group_chat_with_lmm():
"""
Tests the group chat functionality with two MultimodalConversable Agents.
Verifies that the chat is correctly limited by the max_round parameter.
Each agent is set to describe an image in a unique style, but the chat should not exceed the specified max_rounds.
"""
# Configuration parameters
max_round = 5
max_consecutive_auto_reply = 10
llm_config = False
# Creating two MultimodalConversable Agents with different descriptive styles
agent1 = MultimodalConversableAgent(
name="image-explainer-1",
max_consecutive_auto_reply=max_consecutive_auto_reply,
llm_config=llm_config,
system_message="Your image description is poetic and engaging.",
)
agent2 = MultimodalConversableAgent(
name="image-explainer-2",
max_consecutive_auto_reply=max_consecutive_auto_reply,
llm_config=llm_config,
system_message="Your image description is factual and to the point.",
)
# Creating a user proxy agent for initiating the group chat
user_proxy = autogen.UserProxyAgent(
name="User_proxy",
system_message="Ask both image explainer 1 and 2 for their description.",
human_input_mode="NEVER", # Options: 'ALWAYS' or 'NEVER'
max_consecutive_auto_reply=max_consecutive_auto_reply,
)
# Setting up the group chat
groupchat = autogen.GroupChat(agents=[agent1, agent2, user_proxy], messages=[], max_round=max_round)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
# Initiating the group chat and observing the number of rounds
user_proxy.initiate_chat(group_chat_manager, message=f"What do you see? <img {base64_encoded_image}>")
# Assertions to check if the number of rounds does not exceed max_round
assert all(len(arr) <= max_round for arr in agent1._oai_messages.values()), "Agent 1 exceeded max rounds"
assert all(len(arr) <= max_round for arr in agent2._oai_messages.values()), "Agent 2 exceeded max rounds"
assert all(len(arr) <= max_round for arr in user_proxy._oai_messages.values()), "User proxy exceeded max rounds"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -403,7 +403,7 @@ class TestContentStr(unittest.TestCase):
def test_invalid_content(self): def test_invalid_content(self):
content = [{"type": "text", "text": "hello"}, {"type": "wrong_type", "url": "http://example.com/image.png"}] content = [{"type": "text", "text": "hello"}, {"type": "wrong_type", "url": "http://example.com/image.png"}]
with self.assertRaises(AssertionError) as context: with self.assertRaises(ValueError) as context:
content_str(content) content_str(content)
self.assertIn("Wrong content format", str(context.exception)) self.assertIn("Wrong content format", str(context.exception))