Add a generic `stop_when` to runtime (#431)

* Add stop_when

* Format

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Aamir 2024-08-30 12:28:51 -07:00 committed by GitHub
parent 042958e3ab
commit dacd290f1e
1 changed files with 22 additions and 4 deletions

View File

@ -98,30 +98,41 @@ class RunContext:
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
self._runtime = runtime
self._run_state = RunContext.RunState.RUNNING
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
self._run_task = asyncio.create_task(self._run())
self._lock = asyncio.Lock()
async def _run(self) -> None:
while True:
async with self._lock:
if self._run_state == RunContext.RunState.CANCELLED:
if self._end_condition():
return
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
if self._runtime.idle:
return
await self._runtime.process_next()
async def stop(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.CANCELLED
self._end_condition = self._stop_when_cancelled
await self._run_task
async def stop_when_idle(self) -> None:
async with self._lock:
self._run_state = RunContext.RunState.UNTIL_IDLE
self._end_condition = self._stop_when_idle
await self._run_task
async def stop_when(self, condition: Callable[[], bool]) -> None:
async with self._lock:
self._end_condition = condition
await self._run_task
def _stop_when_cancelled(self) -> bool:
return self._run_state == RunContext.RunState.CANCELLED
def _stop_when_idle(self) -> bool:
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
class SingleThreadedAgentRuntime(AgentRuntime):
def __init__(self, *, intervention_handlers: List[InterventionHandler] | None = None) -> None:
@ -449,6 +460,13 @@ class SingleThreadedAgentRuntime(AgentRuntime):
await self._run_context.stop_when_idle()
self._run_context = None
async def stop_when(self, condition: Callable[[], bool]) -> None:
"""Stop the runtime message processing loop when the condition is met."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when(condition)
self._run_context = None
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
return (await self._get_agent(agent)).metadata