mirror of https://github.com/microsoft/autogen.git
Added termination.
This commit is contained in:
parent
dc172b2d0b
commit
33e9670cb6
|
@ -35,7 +35,10 @@ from ._prompts import (
|
|||
ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT,
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
|
||||
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT,
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT,
|
||||
)
|
||||
|
||||
#class LedgerOrchestratorManager(BaseGroupChatManager):
|
||||
|
@ -48,6 +51,8 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
participant_topic_types: List[str],
|
||||
participant_descriptions: List[str],
|
||||
model_client: ChatCompletionClient,
|
||||
max_rounds: int = 20,
|
||||
max_stalls: int = 3,
|
||||
):
|
||||
super().__init__(description="Group chat manager")
|
||||
self._group_topic_type = group_topic_type
|
||||
|
@ -64,9 +69,15 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
|
||||
self._name = "orchestrator"
|
||||
self._model_client = model_client
|
||||
self._max_rounds = max_rounds
|
||||
self._max_stalls = max_stalls
|
||||
|
||||
self._task = None
|
||||
self._facts = None
|
||||
self._plan = None
|
||||
self._n_rounds = 0
|
||||
self._n_stalls = 0
|
||||
|
||||
self._team_description = "\n".join(
|
||||
[
|
||||
f"{topic_type}: {description}".strip()
|
||||
|
@ -88,6 +99,16 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
def _get_progress_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
|
||||
return ORCHESTRATOR_PROGRESS_LEDGER_PROMPT.format(task=task, team=team, names=", ".join(names))
|
||||
|
||||
def _get_task_ledger_facts_update_prompt(self, task: str, facts: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT.format(task=task, facts=facts)
|
||||
|
||||
def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
|
||||
return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team)
|
||||
|
||||
def _get_final_answer_prompt(self, task: str) -> str:
|
||||
return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
|
||||
|
||||
|
||||
@event
|
||||
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
|
||||
"""Handle the start of a group chat by selecting a speaker to start the conversation."""
|
||||
|
@ -124,6 +145,7 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
self._plan = response.content
|
||||
|
||||
# Kick things off
|
||||
self._n_stalls = 0
|
||||
await self._reenter_inner_loop()
|
||||
|
||||
@event
|
||||
|
@ -152,14 +174,23 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
async def _reenter_inner_loop(self):
|
||||
# TODO: Reset the agents
|
||||
|
||||
# Broadcast the new plan
|
||||
# Prepare the ledger
|
||||
ledger_message = TextMessage(
|
||||
content=self._get_task_ledger_full_prompt(self._task, self._team_description, self._facts, self._plan),
|
||||
source=self._name
|
||||
)
|
||||
|
||||
self._message_thread.append(ledger_message) # My copy
|
||||
await self.publish_message( # Broadcast
|
||||
# Save my copy
|
||||
self._message_thread.append(ledger_message)
|
||||
|
||||
# Log it
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=ledger_message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message)),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
|
@ -170,6 +201,12 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
|
||||
async def _orchestrate_step(self) -> None:
|
||||
|
||||
# Check if we reached the maximum number of rounds
|
||||
if self._n_rounds > self._max_rounds:
|
||||
await self._prepare_final_answer("Max rounds reached.")
|
||||
return
|
||||
self._n_rounds += 1
|
||||
|
||||
# Update the progress ledger
|
||||
context = []
|
||||
for m in self._message_thread:
|
||||
|
@ -188,6 +225,23 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
|
||||
progress_ledger = json.loads(response.content)
|
||||
|
||||
# Check for task completion
|
||||
if progress_ledger["is_request_satisfied"]["answer"]:
|
||||
await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"])
|
||||
return
|
||||
|
||||
# Check for stalling
|
||||
if progress_ledger["is_progress_being_made"]["answer"] or progress_ledger["is_in_loop"]["answer"]:
|
||||
self._n_stalls += 1
|
||||
else:
|
||||
self._n_stalls = max(0, self._n_stalls-1)
|
||||
|
||||
# Too much stalling
|
||||
if self._n_stalls >= self._max_stalls:
|
||||
await self._update_task_ledger()
|
||||
await self._reenter_inner_loop()
|
||||
return
|
||||
|
||||
# Broadcst the next step
|
||||
message = TextMessage(
|
||||
content=progress_ledger["instruction_or_question"]["answer"],
|
||||
|
@ -207,6 +261,73 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
|
|||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
|
||||
# Request that it be completed
|
||||
# Request that the step be completed
|
||||
next_speaker = progress_ledger["next_speaker"]["answer"]
|
||||
await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker))
|
||||
|
||||
async def _update_task_ledger(self) -> None:
|
||||
|
||||
context = []
|
||||
for m in self._message_thread:
|
||||
if m.source == self._name:
|
||||
context.append(AssistantMessage(content=m.content, source=m.source))
|
||||
else:
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
|
||||
# Update the facts
|
||||
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
|
||||
context.append(UserMessage(content=update_facts_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(context)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._facts = response.content
|
||||
context.append(AssistantMessage(content=self._facts, source=self._name))
|
||||
|
||||
# Update the plan
|
||||
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
|
||||
context.append(UserMessage(content=update_plan_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(context)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
self._plan = response.content
|
||||
|
||||
async def _prepare_final_answer(self, reason: str) -> None:
|
||||
|
||||
context = []
|
||||
for m in self._message_thread:
|
||||
if m.source == self._name:
|
||||
context.append(AssistantMessage(content=m.content, source=m.source))
|
||||
else:
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
|
||||
# Get the final answer
|
||||
final_answer_prompt = self._get_final_answer_prompt(self._task)
|
||||
context.append(UserMessage(content=final_answer_prompt, source=self._name))
|
||||
|
||||
response = await self._model_client.create(context)
|
||||
message = TextMessage(
|
||||
content=response.content,
|
||||
source=self._name
|
||||
)
|
||||
|
||||
self._message_thread.append(message) # My copy
|
||||
|
||||
# Log it
|
||||
await self.publish_message(
|
||||
GroupChatMessage(message=message),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
||||
# Broadcast
|
||||
await self.publish_message(
|
||||
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
|
||||
# Signal termination
|
||||
await self.publish_message(
|
||||
GroupChatTermination(message=StopMessage(content=reason, source=self._name)),
|
||||
topic_id=DefaultTopicId(type=self._output_topic_type),
|
||||
)
|
||||
|
|
|
@ -53,6 +53,7 @@ Here is the plan to follow as best as possible:
|
|||
{plan}
|
||||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT = """
|
||||
Recall we are working on the following request:
|
||||
|
||||
|
@ -97,7 +98,7 @@ Please output an answer in pure JSON format according to the following schema. T
|
|||
"""
|
||||
|
||||
|
||||
ORCHESTRATOR_UPDATE_FACTS_PROMPT = """As a reminder, we are working to solve the following task:
|
||||
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT = """As a reminder, we are working to solve the following task:
|
||||
|
||||
{task}
|
||||
|
||||
|
@ -108,12 +109,14 @@ Here is the old fact sheet:
|
|||
{facts}
|
||||
"""
|
||||
|
||||
ORCHESTRATOR_UPDATE_PLAN_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else):
|
||||
|
||||
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else):
|
||||
|
||||
{team}
|
||||
"""
|
||||
|
||||
ORCHESTRATOR_GET_FINAL_ANSWER = """
|
||||
|
||||
ORCHESTRATOR_FINAL_ANSWER_PROMPT = """
|
||||
We are working on the following task:
|
||||
{task}
|
||||
|
||||
|
|
Loading…
Reference in New Issue