diff --git a/veadk/a2a/hub/__init__.py b/veadk/a2a/hub/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/veadk/a2a/hub/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/veadk/a2a/hub/a2a_hub_client.py b/veadk/a2a/hub/a2a_hub_client.py new file mode 100644 index 00000000..bdc77c10 --- /dev/null +++ b/veadk/a2a/hub/a2a_hub_client.py @@ -0,0 +1,76 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +from a2a.types import AgentCard + + +class A2AHubClient: + def __init__(self, server_host: str, server_port: int): + self.server_host = server_host + self.server_port = server_port + self.health_check() + + def health_check(self) -> None: + """Check the health of the server.""" + response = requests.get(f"http://{self.server_host}:{self.server_port}/ping") + assert response.status_code == 200, ( + f"unexpected status code from A2A hub server: {response.status_code}" + ) + + def get_agent_cards( + self, group_id: str, target_agents: list[str] = [] + ) -> list[dict]: + """Get the agent cards of the agents in the group.""" + ret = [] + + response = requests.get( + f"http://{self.server_host}:{self.server_port}/group/{group_id}/agents" + ).json() + agent_infos = response["agent_infos"] + for agent_info in agent_infos: + agent_id = agent_info["agent_id"] + if target_agents: + if agent_id in target_agents: + ret.append(agent_info) + else: + ret.append(agent_info) + + return ret + + def register_agent(self, group_id: str, agent_id: str, agent_card: AgentCard): + response = requests.post( + f"http://{self.server_host}:{self.server_port}/register_agent", + json={ + "group_id": group_id, + "agent_id": agent_id, + "agent_card": agent_card.model_dump(), + }, + ) + + assert response.status_code == 200, ( + f"unexpected status code from A2A hub server: {response.status_code}" + ) + + def create_group(self, group_id: str): + response = requests.post( + f"http://{self.server_host}:{self.server_port}/create_group", + params={ + "group_id": group_id, + }, + ) + + assert response.status_code == 200, ( + f"unexpected status code from A2A hub server: {response.status_code}" + ) diff --git a/veadk/a2a/hub/a2a_hub_server.py b/veadk/a2a/hub/a2a_hub_server.py new file mode 100644 index 00000000..b0ae8e43 --- /dev/null +++ b/veadk/a2a/hub/a2a_hub_server.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uvicorn +from fastapi import FastAPI +from fastapi.responses import JSONResponse + +from veadk.a2a.hub.models import ( + AgentInformation, + GetAgentResponse, + GetAgentsResponse, + GetGroupsResponse, + RegisterAgentRequest, + RegisterAgentResponse, + RegisterGroupResponse, +) + + +class A2AHubServer: + def __init__(self): + self.app = FastAPI() + + self.groups: list[str] = [] + + # group_id -> agent_id -> agent_card + self.agent_cards: dict[str, dict[str, dict]] = {} + + @self.app.get("/ping") + def ping() -> JSONResponse: + return JSONResponse(content={"msg": "pong!"}) + + @self.app.post("/create_group") + def create_group(group_id: str) -> RegisterGroupResponse: + """Create a group.""" + self.groups.append(group_id) + self.agent_cards[group_id] = {} + return RegisterGroupResponse(group_id=group_id) + + @self.app.post("/register_agent") + def register_agent( + request: RegisterAgentRequest, + ) -> RegisterAgentResponse: + """Register an agent to a specified group.""" + if request.group_id not in self.groups: + return RegisterAgentResponse( + err_code=1, msg=f"group {request.group_id} not exist" + ) + if request.agent_id in self.agent_cards[request.group_id]: + return RegisterAgentResponse( + err_code=1, msg=f"agent {request.agent_id} already exist" + ) + + self.agent_cards[request.group_id][request.agent_id] = request.agent_card + return RegisterAgentResponse( + group_id=request.group_id, + agent_id=request.agent_id, + agent_card=self.agent_cards[request.group_id][request.agent_id], + ) + + @self.app.get("/group/{group_id}/agents") + def agents(group_id: str) -> GetAgentsResponse: + """Get all agents in a specified group.""" + if group_id not in self.groups: + return GetAgentsResponse(err_code=1, msg=f"group {group_id} not exist") + + agent_infos = [ + AgentInformation(agent_id=agent_id, agent_card=agent_card) + for agent_id, agent_card in self.agent_cards[group_id].items() + ] + return GetAgentsResponse(group_id=group_id, agent_infos=agent_infos) + + @self.app.get("/group/{group_id}/agent/{agent_id}") + def agent(group_id: str, agent_id: str) -> GetAgentResponse: + """Get the agent card of a specified agent in a specified group.""" + if group_id not in self.groups: + return GetAgentResponse(err_code=1, msg=f"group {group_id} not exist") + if agent_id not in self.agent_cards[group_id]: + return GetAgentResponse( + err_code=1, + msg=f"agent {agent_id} in group {group_id} not exist", + ) + return GetAgentResponse( + agent_id=agent_id, + agent_card=self.agent_cards[group_id][agent_id], + ) + + @self.app.get("/groups") + def groups() -> GetGroupsResponse: + """Get all registered groups.""" + return GetGroupsResponse(group_ids=self.groups) + + def serve(self, **kwargs): + uvicorn.run(self.app, **kwargs) diff --git a/veadk/a2a/hub/models.py b/veadk/a2a/hub/models.py new file mode 100644 index 00000000..23737494 --- /dev/null +++ b/veadk/a2a/hub/models.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel, Field + + +class BaseResponse(BaseModel): + err_code: int = 0 + + msg: str = "" + """The message of the response.""" + + +class RegisterGroupResponse(BaseResponse): + group_id: str = "" + """The id of the group.""" + + +class RegisterAgentRequest(BaseResponse): + group_id: str + """Target group id.""" + + agent_id: str + """The id of the agent.""" + + agent_card: dict + """The agent card of the agent in json format.""" + + +class RegisterAgentResponse(BaseResponse): + group_id: str = "" + """Target group id.""" + + agent_id: str = "" + """The id of the agent.""" + + agent_card: dict = Field(default_factory=dict) + """The agent card of the agent in json format.""" + + +class AgentInformation(BaseModel): + agent_id: str = "" + """The id of the agent.""" + + agent_card: dict = Field(default_factory=dict) + """The agent card of the agent in json format.""" + + +class GetAgentsResponse(BaseResponse): + group_id: str = "" + """Target group id.""" + + agent_infos: list[AgentInformation] = Field(default_factory=list) + """The agent cards of the agents in json format.""" + + +class GetAgentResponse(BaseResponse): + agent_id: str = "" + """The id of the agent.""" + + agent_card: dict = Field(default_factory=dict) + """The agent card of the agent in json format.""" + + +class GetGroupsResponse(BaseResponse): + group_ids: list[str] = Field(default_factory=list) + """The ids of the groups."""