171 lines
5.4 KiB
Python
171 lines
5.4 KiB
Python
"""Agent-to-agent invocation surface available via ``ctx.call(...)``.

An agent never speaks raw HTTP to another agent. Instead it calls
``ctx.call(target, skill, args, grant=...)``, and the runtime-attached
:class:`A2AClient` handles the transport: HTTP across pods, in-memory for
local tests, and other transports (gRPC, a message bus) in future runtimes.

The grant token (see :mod:`a2a_pack.grants`) is the *only* way to hand
workspace access across agents. The callee-side runtime validates it before
materializing a :class:`WorkspaceClient`.
"""
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
if TYPE_CHECKING:
|
|
from .agent import A2AAgent
|
|
from .context import RunContext
|
|
|
|
|
|
@dataclass(frozen=True)
class CallResult:
    """Immutable outcome of one agent-to-agent call, as seen by the caller.

    Fields:
        result: the value returned by the callee's skill.
        events: event records produced during the callee run, as dicts.
        artifacts: artifact descriptors produced during the callee run.
        grant_id: ``grant_id`` read from the grant token that accompanied
            the call, echoed back for auditing; ``None`` when absent.
    """

    result: Any
    events: tuple[dict[str, Any], ...] = field(default=())
    artifacts: tuple[dict[str, Any], ...] = field(default=())
    # Audit echo only; never used for authorization.
    grant_id: str | None = field(default=None)
|
|
|
|
|
|
class A2AClient(ABC):
    """Abstract, transport-agnostic client for agent-to-agent calls.

    Concrete subclasses decide how a call travels (HTTP, in-process, ...);
    callers only ever see :meth:`call` and the :class:`CallResult` it
    returns.
    """

    @abstractmethod
    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """Run ``skill`` on the agent identified by ``target``.

        Args:
            target: address of the callee. Opaque at this layer; the HTTP
                implementation treats it as an agent URL and the in-memory
                implementation as an agent name.
            skill: name of the skill to invoke on the callee.
            args: arguments for the skill, or ``None`` for none.
            grant: optional grant token handing workspace access to the
                callee.
            timeout: optional per-call timeout.

        Returns:
            The callee's :class:`CallResult`.
        """
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# In-memory: routes calls to A2AAgent instances in the same process. Useful
|
|
# for the demo + tests.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class InMemoryA2AClient(A2AClient):
    """Routes calls to agent instances registered by name, in-process.

    The receiving agent gets a *new* :class:`RunContext` built by the
    ``ctx_factory`` callable, so caller and callee don't share state.
    Pass ``ctx_factory=lambda agent, grant: ...`` to control how scoped
    workspaces / sandboxes are wired in. When no factory is set, or the
    factory returns ``None``, a ``LocalRunContext`` with ``NoAuth`` is used.

    Useful for demos and tests.
    """

    # Agents reachable through this client, keyed by the name callers pass
    # as ``target``.
    agents: dict[str, "A2AAgent"]
    # Optional Callable[[A2AAgent, str | None], RunContext | None].
    ctx_factory: Any = None

    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """Invoke ``skill`` on the agent registered under ``target``.

        Args:
            target: name of an agent in ``self.agents``.
            skill: skill name passed to the agent's ``invoke_json``.
            args: skill arguments; ``None`` is sent as ``{}``.
            grant: grant token forwarded to ``ctx_factory`` and echoed (as
                its ``grant_id``) in the result. It is not validated here.
            timeout: optional limit on the skill invocation, in seconds.
                A falsy value (``None`` or ``0``) means no limit.

        Returns:
            A :class:`CallResult` with the skill's return value, the events
            recorded on the callee context, and name/size descriptors for
            any artifacts that context captured.

        Raises:
            KeyError: if no agent is registered under ``target``.
            asyncio.TimeoutError: if ``timeout`` elapses before the skill
                finishes.
        """
        import asyncio  # local import, in line with this module's late imports

        try:
            agent = self.agents[target]
        except KeyError:
            raise KeyError(f"no agent registered: {target!r}") from None

        ctx = self.ctx_factory(agent, grant) if self.ctx_factory else None
        if ctx is None:
            # Fallback for demos/tests: an unauthenticated local context.
            from .context import LocalRunContext
            from .auth import NoAuth

            ctx = LocalRunContext(auth=NoAuth(), task_id=f"a2a-{target}")

        invocation = agent.invoke_json(skill, ctx, args or {})
        if timeout:
            # Previously the timeout was accepted but silently ignored;
            # enforce it so in-process calls honor the same contract as
            # remote ones.
            result = await asyncio.wait_for(invocation, timeout)
        else:
            result = await invocation

        events = tuple(
            {"kind": event.kind, "payload": event.payload}
            for event in getattr(ctx, "events", ())
        )

        # Surface artifacts only when the context keeps them in a dict
        # (e.g. LocalRunContext); each is reported as its name plus
        # ``len(data)`` under ``size_bytes``.
        artifacts: tuple[dict[str, Any], ...] = ()
        captured = getattr(ctx, "artifacts", None)
        if isinstance(captured, dict):
            artifacts = tuple(
                {"name": name, "size_bytes": len(data)}
                for name, data in captured.items()
            )

        return CallResult(
            result=result,
            events=events,
            artifacts=artifacts,
            grant_id=_grant_id_or_none(grant),
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HTTP: posts to <target>/invoke/<skill> with {arguments, grant} body.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class HttpA2AClient(A2AClient):
    """A2A client that POSTs to the standard ``/invoke/{skill}`` endpoint.

    ``target`` is the callee agent's base URL. The request body is
    ``{"arguments": ..., "grant": ...}`` (``grant`` only when one is given).
    The JSON response is expected to be an object with optional ``result``,
    ``events`` and ``artifacts`` keys.
    """

    # Timeout passed to httpx when ``call`` gets no (or a zero) ``timeout``.
    default_timeout: float = 60.0

    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """POST ``args`` (and ``grant``) to ``<target>/invoke/<skill>``.

        Args:
            target: base URL of the callee agent; a trailing ``/`` is ignored.
            skill: skill name; percent-encoded into a single path segment.
            args: skill arguments; ``None`` is sent as ``{}``.
            grant: optional grant token, forwarded verbatim in the body.
            timeout: httpx timeout for this call; falls back to
                ``default_timeout`` when ``None`` or ``0``.

        Returns:
            A :class:`CallResult` built from the response's ``result``,
            ``events`` and ``artifacts`` keys, plus the grant's ``grant_id``.

        Raises:
            RuntimeError: on an HTTP status >= 400, or when the response
                body is JSON but not an object.
        """
        import httpx  # late import: server-side needs no client
        from urllib.parse import quote

        body: dict[str, Any] = {"arguments": args or {}}
        if grant is not None:
            body["grant"] = grant

        # Percent-encode the skill so characters such as "?", "#" or "%"
        # stay inside the path segment instead of changing the URL shape.
        url = f"{target.rstrip('/')}/invoke/{quote(skill, safe='')}"

        async with httpx.AsyncClient(timeout=timeout or self.default_timeout) as c:
            resp = await c.post(url, json=body)

        if resp.status_code >= 400:
            raise RuntimeError(f"a2a {url} -> {resp.status_code}: {resp.text}")

        data = resp.json()
        if not isinstance(data, dict):
            # Previously this surfaced as an opaque AttributeError on ``.get``.
            raise RuntimeError(
                f"a2a {url} -> expected a JSON object, got {type(data).__name__}"
            )

        return CallResult(
            result=data.get("result"),
            events=tuple(data.get("events") or ()),
            artifacts=tuple(data.get("artifacts") or ()),
            grant_id=_grant_id_or_none(grant),
        )
|
|
|
|
|
|
def _grant_id_or_none(grant: str | None) -> str | None:
|
|
"""Extract grant_id without re-validating the signature (audit only)."""
|
|
if not grant or "." not in grant:
|
|
return None
|
|
try:
|
|
from .grants import _b64decode
|
|
|
|
payload = _b64decode(grant.rsplit(".", 1)[0])
|
|
import json
|
|
|
|
return json.loads(payload).get("grant_id")
|
|
except Exception: # noqa: BLE001
|
|
return None
|
|
|
|
|
|
# Public API of this module; ``_grant_id_or_none`` remains private.
__all__ = ["A2AClient", "CallResult", "HttpA2AClient", "InMemoryA2AClient"]
|