Files
a2a/a2a_pack/a2a_client.py

171 lines
5.4 KiB
Python

"""Agent-to-agent invocation surface available via ``ctx.call(...)``.
An agent never speaks raw HTTP to another agent. It calls
``ctx.call(target, skill, args, grant=...)`` and the runtime-attached
:class:`A2AClient` handles transport: HTTP for cross-pod, in-memory for
local tests, anything else (gRPC, message bus) for future runtimes.
The grant token (see :mod:`a2a_pack.grants`) is the *only* way to hand
workspace access from one agent to another. The callee-side runtime
validates it before materializing a :class:`WorkspaceClient`.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .agent import A2AAgent
from .context import RunContext
@dataclass(frozen=True)
class CallResult:
    """Immutable outcome of one agent-to-agent call, handed back to the caller.

    ``result`` is whatever the callee skill returned; ``events`` and
    ``artifacts`` are plain-dict summaries produced by the client transport.
    """

    result: Any
    events: tuple[dict[str, Any], ...] = field(default=())
    artifacts: tuple[dict[str, Any], ...] = field(default=())
    # Identifier of the grant used for the call, echoed back for auditing.
    grant_id: str | None = field(default=None)
class A2AClient(ABC):
    """Abstract agent-to-agent client; concrete subclasses choose the transport.

    Skills reach this through ``ctx.call(...)`` and never talk to another
    agent directly.
    """

    @abstractmethod
    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """Run ``skill`` on the agent identified by ``target``.

        Args:
            target: Transport-specific address, opaque at this level — an
                agent URL for the HTTP client, a registered agent name for
                the in-memory client.
            skill: Name of the skill to invoke.
            args: Arguments for the skill.
            grant: Optional grant token delegating workspace access.
            timeout: Optional per-call timeout.

        Returns:
            The callee's :class:`CallResult`.
        """
# ---------------------------------------------------------------------------
# In-memory: routes calls to A2AAgent instances in the same process. Useful
# for the demo + tests.
# ---------------------------------------------------------------------------
@dataclass
class InMemoryA2AClient(A2AClient):
    """Route A2A calls to agent instances living in the same process.

    Intended for demos and tests: ``target`` is a key into :attr:`agents`
    rather than a URL. Every call runs the callee with a *fresh* context
    built by ``ctx_factory``, so caller and callee don't share state.

    Pass ``ctx_factory=lambda agent, grant: ...`` to control how scoped
    workspaces / sandboxes are wired in. When no factory is configured, or
    the factory returns ``None``, a ``LocalRunContext`` with ``NoAuth`` is
    used instead.
    """

    # Registry of callable agents, keyed by the name used as ``target``.
    agents: dict[str, "A2AAgent"]
    # Optional Callable[[A2AAgent, str | None], RunContext | None]; receives
    # the target agent and the raw grant token (not validated here).
    ctx_factory: Any = None

    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """Invoke ``skill`` on the agent registered under ``target``.

        Args:
            target: Name of a registered agent.
            skill: Skill name forwarded to ``agent.invoke_json``.
            args: Skill arguments; ``None`` is sent as ``{}``.
            grant: Optional grant token, passed to ``ctx_factory`` and echoed
                (as its ``grant_id``) on the result.
            timeout: Seconds to wait for the skill; ``None`` waits forever.

        Returns:
            A :class:`CallResult` carrying the skill's return value, any
            events the callee context exposes, and name/size summaries of
            any artifacts it captured.

        Raises:
            KeyError: ``target`` is not registered.
            asyncio.TimeoutError: the skill did not finish within ``timeout``.
        """
        import asyncio  # local import, matching this module's lazy-import style

        try:
            agent = self.agents[target]
        except KeyError:
            raise KeyError(f"no agent registered: {target!r}") from None

        ctx = self.ctx_factory(agent, grant) if self.ctx_factory is not None else None
        if ctx is None:
            # Fallback context; imported lazily since it is only needed here.
            from .context import LocalRunContext
            from .auth import NoAuth

            ctx = LocalRunContext(auth=NoAuth(), task_id=f"a2a-{target}")

        invocation = agent.invoke_json(skill, ctx, args or {})
        if timeout is None:
            result = await invocation
        else:
            # Honor the per-call timeout, same contract as the HTTP client.
            result = await asyncio.wait_for(invocation, timeout)

        # Contexts without an ``events`` attribute (or with None) yield none.
        captured_events = getattr(ctx, "events", None) or ()
        events = tuple(
            {"kind": e.kind, "payload": e.payload} for e in captured_events
        )

        # Surface artifacts captured as a name -> data mapping (as
        # LocalRunContext does); only name and size are reported.
        artifacts: tuple[dict[str, Any], ...] = ()
        local_arts = getattr(ctx, "artifacts", None)
        if isinstance(local_arts, dict):
            artifacts = tuple(
                {"name": name, "size_bytes": len(data)}
                for name, data in local_arts.items()
            )

        return CallResult(
            result=result,
            events=events,
            artifacts=artifacts,
            grant_id=_grant_id_or_none(grant),
        )
# ---------------------------------------------------------------------------
# HTTP: posts to <target>/invoke/<skill> with {arguments, grant} body.
# ---------------------------------------------------------------------------
@dataclass
class HttpA2AClient(A2AClient):
    """A2A client that POSTs to the standard ``/invoke/{skill}`` endpoint.

    The request body is ``{"arguments": {...}}``, plus ``"grant"`` when a
    grant token is supplied. The response is read as a JSON object with
    ``result`` and optional ``events`` / ``artifacts`` lists.
    """

    # Used when ``call`` is not given an explicit ``timeout``.
    default_timeout: float = 60.0

    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """POST to ``<target>/invoke/<skill>`` and wrap the reply.

        Args:
            target: Base URL of the callee agent (trailing ``/`` ignored).
            skill: Skill name; percent-encoded into a single path segment.
            args: Skill arguments; ``None`` is sent as ``{}``.
            grant: Optional grant token forwarded in the body.
            timeout: httpx timeout for this call; ``None`` means use
                :attr:`default_timeout`.

        Returns:
            A :class:`CallResult` built from the response's ``result``,
            ``events`` and ``artifacts`` keys.

        Raises:
            RuntimeError: the server answered with status >= 400, or with a
                JSON body that is not an object.
        """
        import httpx  # late import: server-side needs no client
        from urllib.parse import quote

        body: dict[str, Any] = {"arguments": args or {}}
        if grant is not None:
            body["grant"] = grant

        # Encode the skill so characters such as '?', '#', '/' or spaces
        # cannot escape the path segment.
        url = f"{target.rstrip('/')}/invoke/{quote(skill, safe='')}"
        # ``is None`` (not truthiness) so an explicit 0 is not replaced.
        effective_timeout = self.default_timeout if timeout is None else timeout

        # NOTE(review): a fresh AsyncClient per call means no connection
        # reuse between calls; consider a shared client if traffic is hot.
        async with httpx.AsyncClient(timeout=effective_timeout) as c:
            resp = await c.post(url, json=body)

        if resp.status_code >= 400:
            raise RuntimeError(f"a2a {url} -> {resp.status_code}: {resp.text}")
        data = resp.json()
        if not isinstance(data, dict):
            raise RuntimeError(
                f"a2a {url} -> expected a JSON object, got {type(data).__name__}"
            )
        return CallResult(
            result=data.get("result"),
            events=tuple(data.get("events") or ()),
            artifacts=tuple(data.get("artifacts") or ()),
            grant_id=_grant_id_or_none(grant),
        )
def _grant_id_or_none(grant: str | None) -> str | None:
"""Extract grant_id without re-validating the signature (audit only)."""
if not grant or "." not in grant:
return None
try:
from .grants import _b64decode
payload = _b64decode(grant.rsplit(".", 1)[0])
import json
return json.loads(payload).get("grant_id")
except Exception: # noqa: BLE001
return None
# Public API of this module (alphabetical).
__all__ = ["A2AClient", "CallResult", "HttpA2AClient", "InMemoryA2AClient"]