allow class associated subscriptions to be skipped on register (#587)

* allow class associated subscriptions to be skipped on register

* format
This commit is contained in:
Jack Gerrits 2024-09-19 15:50:59 -04:00 committed by GitHub
parent 7f25d28aac
commit d6dce9ebb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 19 additions and 14 deletions

View File

@ -144,23 +144,28 @@ class BaseAgent(ABC, Agent):
@classmethod
async def register(
cls, runtime: AgentRuntime, type: str, factory: Callable[[], Self | Awaitable[Self]]
cls,
runtime: AgentRuntime,
type: str,
factory: Callable[[], Self | Awaitable[Self]],
*,
skip_class_subscriptions: bool = False,
) -> AgentType:
agent_type = AgentType(type)
with SubscriptionInstantiationContext.populate_context(agent_type):
subscriptions = []
for unbound_subscription in cls._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
subscriptions.extend(subscriptions_list)
agent_type = await runtime.register_factory(type=agent_type, agent_factory=factory, expected_class=cls)
for subscription in subscriptions:
await runtime.add_subscription(subscription)
if not skip_class_subscriptions:
with SubscriptionInstantiationContext.populate_context(agent_type):
subscriptions = []
for unbound_subscription in cls._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
subscriptions.extend(subscriptions_list)
for subscription in subscriptions:
await runtime.add_subscription(subscription)
# TODO: deduplication
for _message_type, serializer in cls._handles_types():