Added termination.

This commit is contained in:
Adam Fourney 2024-11-12 22:21:31 -08:00
parent dc172b2d0b
commit 33e9670cb6
2 changed files with 131 additions and 7 deletions

View File

@ -35,7 +35,10 @@ from ._prompts import (
ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT,
ORCHESTRATOR_TASK_LEDGER_PLAN_PROMPT,
ORCHESTRATOR_TASK_LEDGER_FULL_PROMPT,
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT,
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT,
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT,
ORCHESTRATOR_FINAL_ANSWER_PROMPT,
)
#class LedgerOrchestratorManager(BaseGroupChatManager):
@ -48,6 +51,8 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
participant_topic_types: List[str],
participant_descriptions: List[str],
model_client: ChatCompletionClient,
max_rounds: int = 20,
max_stalls: int = 3,
):
super().__init__(description="Group chat manager")
self._group_topic_type = group_topic_type
@ -64,9 +69,15 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
self._name = "orchestrator"
self._model_client = model_client
self._max_rounds = max_rounds
self._max_stalls = max_stalls
self._task = None
self._facts = None
self._plan = None
self._n_rounds = 0
self._n_stalls = 0
self._team_description = "\n".join(
[
f"{topic_type}: {description}".strip()
@ -88,6 +99,16 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
def _get_progress_ledger_prompt(self, task: str, team: str, names: List[str]) -> str:
return ORCHESTRATOR_PROGRESS_LEDGER_PROMPT.format(task=task, team=team, names=", ".join(names))
def _get_task_ledger_facts_update_prompt(self, task: str, facts: str) -> str:
return ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT.format(task=task, facts=facts)
def _get_task_ledger_plan_update_prompt(self, team: str) -> str:
return ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT.format(team=team)
def _get_final_answer_prompt(self, task: str) -> str:
return ORCHESTRATOR_FINAL_ANSWER_PROMPT.format(task=task)
@event
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
"""Handle the start of a group chat by selecting a speaker to start the conversation."""
@ -124,6 +145,7 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
self._plan = response.content
# Kick things off
self._n_stalls = 0
await self._reenter_inner_loop()
@event
@ -152,14 +174,23 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
async def _reenter_inner_loop(self):
# TODO: Reset the agents
# Broadcast the new plan
# Prepare the ledger
ledger_message = TextMessage(
content=self._get_task_ledger_full_prompt(self._task, self._team_description, self._facts, self._plan),
source=self._name
)
self._message_thread.append(ledger_message) # My copy
await self.publish_message( # Broadcast
# Save my copy
self._message_thread.append(ledger_message)
# Log it
await self.publish_message(
GroupChatMessage(message=ledger_message),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Broadcast
await self.publish_message(
GroupChatAgentResponse(agent_response=Response(chat_message=ledger_message)),
topic_id=DefaultTopicId(type=self._group_topic_type),
)
@ -170,6 +201,12 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
async def _orchestrate_step(self) -> None:
# Check if we reached the maximum number of rounds
if self._n_rounds > self._max_rounds:
await self._prepare_final_answer("Max rounds reached.")
return
self._n_rounds += 1
# Update the progress ledger
context = []
for m in self._message_thread:
@ -188,6 +225,23 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
progress_ledger = json.loads(response.content)
# Check for task completion
if progress_ledger["is_request_satisfied"]["answer"]:
await self._prepare_final_answer(progress_ledger["is_request_satisfied"]["reason"])
return
# Check for stalling
if progress_ledger["is_progress_being_made"]["answer"] or progress_ledger["is_in_loop"]["answer"]:
self._n_stalls += 1
else:
self._n_stalls = max(0, self._n_stalls-1)
# Too much stalling
if self._n_stalls >= self._max_stalls:
await self._update_task_ledger()
await self._reenter_inner_loop()
return
# Broadcst the next step
message = TextMessage(
content=progress_ledger["instruction_or_question"]["answer"],
@ -207,6 +261,73 @@ class LedgerOrchestratorManager(SequentialRoutedAgent, ABC):
topic_id=DefaultTopicId(type=self._group_topic_type),
)
# Request that it be completed
# Request that the step be completed
next_speaker = progress_ledger["next_speaker"]["answer"]
await self.publish_message(GroupChatRequestPublish(), topic_id=DefaultTopicId(type=next_speaker))
async def _update_task_ledger(self) -> None:
context = []
for m in self._message_thread:
if m.source == self._name:
context.append(AssistantMessage(content=m.content, source=m.source))
else:
context.append(UserMessage(content=m.content, source=m.source))
# Update the facts
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
context.append(UserMessage(content=update_facts_prompt, source=self._name))
response = await self._model_client.create(context)
assert isinstance(response.content, str)
self._facts = response.content
context.append(AssistantMessage(content=self._facts, source=self._name))
# Update the plan
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
context.append(UserMessage(content=update_plan_prompt, source=self._name))
response = await self._model_client.create(context)
assert isinstance(response.content, str)
self._plan = response.content
async def _prepare_final_answer(self, reason: str) -> None:
context = []
for m in self._message_thread:
if m.source == self._name:
context.append(AssistantMessage(content=m.content, source=m.source))
else:
context.append(UserMessage(content=m.content, source=m.source))
# Get the final answer
final_answer_prompt = self._get_final_answer_prompt(self._task)
context.append(UserMessage(content=final_answer_prompt, source=self._name))
response = await self._model_client.create(context)
message = TextMessage(
content=response.content,
source=self._name
)
self._message_thread.append(message) # My copy
# Log it
await self.publish_message(
GroupChatMessage(message=message),
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Broadcast
await self.publish_message(
GroupChatAgentResponse(agent_response=Response(chat_message=message)),
topic_id=DefaultTopicId(type=self._group_topic_type),
)
# Signal termination
await self.publish_message(
GroupChatTermination(message=StopMessage(content=reason, source=self._name)),
topic_id=DefaultTopicId(type=self._output_topic_type),
)

View File

@ -53,6 +53,7 @@ Here is the plan to follow as best as possible:
{plan}
"""
ORCHESTRATOR_PROGRESS_LEDGER_PROMPT = """
Recall we are working on the following request:
@ -97,7 +98,7 @@ Please output an answer in pure JSON format according to the following schema. T
"""
ORCHESTRATOR_UPDATE_FACTS_PROMPT = """As a reminder, we are working to solve the following task:
ORCHESTRATOR_TASK_LEDGER_FACTS_UPDATE_PROMPT = """As a reminder, we are working to solve the following task:
{task}
@ -108,12 +109,14 @@ Here is the old fact sheet:
{facts}
"""
ORCHESTRATOR_UPDATE_PLAN_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else):
ORCHESTRATOR_TASK_LEDGER_PLAN_UPDATE_PROMPT = """Please briefly explain what went wrong on this last run (the root cause of the failure), and then come up with a new plan that takes steps and/or includes hints to overcome prior challenges and especially avoids repeating the same mistakes. As before, the new plan should be concise, be expressed in bullet-point form, and consider the following team composition (do not involve any other outside people since we cannot contact anyone else):
{team}
"""
ORCHESTRATOR_GET_FINAL_ANSWER = """
ORCHESTRATOR_FINAL_ANSWER_PROMPT = """
We are working on the following task:
{task}