mirror of https://github.com/microsoft/autogen.git
Add a generic `stop_when` to runtime (#431)
* Add stop_when * Format --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
042958e3ab
commit
dacd290f1e
|
@ -98,30 +98,41 @@ class RunContext:
|
|||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||
self._runtime = runtime
|
||||
self._run_state = RunContext.RunState.RUNNING
|
||||
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
|
||||
self._run_task = asyncio.create_task(self._run())
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _run(self) -> None:
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._run_state == RunContext.RunState.CANCELLED:
|
||||
if self._end_condition():
|
||||
return
|
||||
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
|
||||
if self._runtime.idle:
|
||||
return
|
||||
|
||||
await self._runtime.process_next()
|
||||
|
||||
async def stop(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.CANCELLED
|
||||
self._end_condition = self._stop_when_cancelled
|
||||
await self._run_task
|
||||
|
||||
async def stop_when_idle(self) -> None:
|
||||
async with self._lock:
|
||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||
self._end_condition = self._stop_when_idle
|
||||
await self._run_task
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
async with self._lock:
|
||||
self._end_condition = condition
|
||||
await self._run_task
|
||||
|
||||
def _stop_when_cancelled(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.CANCELLED
|
||||
|
||||
def _stop_when_idle(self) -> bool:
|
||||
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
|
||||
|
||||
|
||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
def __init__(self, *, intervention_handlers: List[InterventionHandler] | None = None) -> None:
|
||||
|
@ -449,6 +460,13 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||
await self._run_context.stop_when_idle()
|
||||
self._run_context = None
|
||||
|
||||
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||
"""Stop the runtime message processing loop when the condition is met."""
|
||||
if self._run_context is None:
|
||||
raise RuntimeError("Runtime is not started")
|
||||
await self._run_context.stop_when(condition)
|
||||
self._run_context = None
|
||||
|
||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||
return (await self._get_agent(agent)).metadata
|
||||
|
||||
|
|
Loading…
Reference in New Issue