mirror of https://github.com/microsoft/autogen.git
Add a generic `stop_when` to runtime (#431)
* Add stop_when * Format --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
042958e3ab
commit
dacd290f1e
|
@ -98,30 +98,41 @@ class RunContext:
|
||||||
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
def __init__(self, runtime: SingleThreadedAgentRuntime) -> None:
|
||||||
self._runtime = runtime
|
self._runtime = runtime
|
||||||
self._run_state = RunContext.RunState.RUNNING
|
self._run_state = RunContext.RunState.RUNNING
|
||||||
|
self._end_condition: Callable[[], bool] = self._stop_when_cancelled
|
||||||
self._run_task = asyncio.create_task(self._run())
|
self._run_task = asyncio.create_task(self._run())
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def _run(self) -> None:
|
async def _run(self) -> None:
|
||||||
while True:
|
while True:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
if self._run_state == RunContext.RunState.CANCELLED:
|
if self._end_condition():
|
||||||
return
|
return
|
||||||
elif self._run_state == RunContext.RunState.UNTIL_IDLE:
|
|
||||||
if self._runtime.idle:
|
|
||||||
return
|
|
||||||
|
|
||||||
await self._runtime.process_next()
|
await self._runtime.process_next()
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._run_state = RunContext.RunState.CANCELLED
|
self._run_state = RunContext.RunState.CANCELLED
|
||||||
|
self._end_condition = self._stop_when_cancelled
|
||||||
await self._run_task
|
await self._run_task
|
||||||
|
|
||||||
async def stop_when_idle(self) -> None:
|
async def stop_when_idle(self) -> None:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._run_state = RunContext.RunState.UNTIL_IDLE
|
self._run_state = RunContext.RunState.UNTIL_IDLE
|
||||||
|
self._end_condition = self._stop_when_idle
|
||||||
await self._run_task
|
await self._run_task
|
||||||
|
|
||||||
|
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
self._end_condition = condition
|
||||||
|
await self._run_task
|
||||||
|
|
||||||
|
def _stop_when_cancelled(self) -> bool:
|
||||||
|
return self._run_state == RunContext.RunState.CANCELLED
|
||||||
|
|
||||||
|
def _stop_when_idle(self) -> bool:
|
||||||
|
return self._run_state == RunContext.RunState.UNTIL_IDLE and self._runtime.idle
|
||||||
|
|
||||||
|
|
||||||
class SingleThreadedAgentRuntime(AgentRuntime):
|
class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
def __init__(self, *, intervention_handlers: List[InterventionHandler] | None = None) -> None:
|
def __init__(self, *, intervention_handlers: List[InterventionHandler] | None = None) -> None:
|
||||||
|
@ -449,6 +460,13 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||||
await self._run_context.stop_when_idle()
|
await self._run_context.stop_when_idle()
|
||||||
self._run_context = None
|
self._run_context = None
|
||||||
|
|
||||||
|
async def stop_when(self, condition: Callable[[], bool]) -> None:
|
||||||
|
"""Stop the runtime message processing loop when the condition is met."""
|
||||||
|
if self._run_context is None:
|
||||||
|
raise RuntimeError("Runtime is not started")
|
||||||
|
await self._run_context.stop_when(condition)
|
||||||
|
self._run_context = None
|
||||||
|
|
||||||
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
async def agent_metadata(self, agent: AgentId) -> AgentMetadata:
|
||||||
return (await self._get_agent(agent)).metadata
|
return (await self._get_agent(agent)).metadata
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue