Files
a2a/a2a_pack/context.py

264 lines
8.4 KiB
Python

"""Runtime context handed to skill handlers.
The same agent code runs unchanged on local dev, Docker, Kubernetes, and
hosted runtimes — the runtime provides a concrete :class:`RunContext` that
implements artifact storage, secret access, streaming, and cancellation.
"""
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Generic, Sequence, TypeVar
from pydantic import BaseModel
from .a2a_client import A2AClient, CallResult
from .discovery import DiscoveryClient
from .sandbox import SandboxClient, SandboxUnavailable
from .workspace import WorkspaceClient
# Type of the caller-identity model carried on ``RunContext.auth``; any
# pydantic ``BaseModel`` subclass is accepted.
AuthT = TypeVar("AuthT", bound=BaseModel)
class CancelledByCaller(RuntimeError):
    """Signals that the caller cancelled the running task.

    Raised from :meth:`RunContext.check_cancelled`; handlers may let it
    propagate to abort the skill run.
    """
class MissingScopes(PermissionError):
    """Signals that the caller's auth lacks one or more required scopes.

    Raised by :meth:`RunContext.require_scopes`. The missing scopes are kept,
    in the order they were supplied, as a tuple on :attr:`missing`; the
    exception message lists them sorted.
    """

    def __init__(self, missing: Sequence[str]) -> None:
        lacking = tuple(missing)
        self.missing = lacking
        listing = sorted(lacking)
        super().__init__(f"missing scopes: {listing}")
@dataclass(frozen=True)
class ArtifactRef:
    """Immutable handle to an artifact (blob, file, etc.) a runtime stored.

    Handlers should treat ``uri`` as opaque: its scheme and layout belong to
    whichever runtime produced the reference.
    """

    # Logical name the artifact was written under.
    name: str
    # Runtime-specific location of the stored payload.
    uri: str
    # Media type supplied when the artifact was written.
    mime_type: str
    # Size of the stored payload in bytes.
    size_bytes: int
@dataclass(frozen=True)
class AgentEvent:
    """A structured event emitted during a skill run.

    ``kind`` names the event type (the helpers on ``RunContext`` use e.g.
    ``"progress"``, ``"text_delta"``, ``"artifact"``, ``"error"``) and
    ``payload`` carries its data.

    The dataclass is frozen and compares by value. ``payload`` is excluded
    from the generated ``__hash__`` because a ``dict`` is unhashable;
    including it made ``hash(event)`` raise ``TypeError`` despite the class
    advertising hashability. Equality still compares ``payload``, so equal
    events always hash equally.
    """

    kind: str
    payload: dict[str, Any] = field(default_factory=dict, hash=False)
class RunContext(ABC, Generic[AuthT]):
    """Per-invocation context handed to a skill handler.

    The runtime constructs a new context for every skill call. It carries
    the caller identity (``auth``), the task identity (``task_id``), and
    runtime capabilities: artifact storage, secrets, event streaming,
    cancellation, workspace/sandbox access, discovery, and outbound A2A
    calls.

    Agents must depend only on this abstract interface, never on a concrete
    runtime implementation. Subclasses implement the abstract members; the
    ``emit_*`` helpers, :meth:`call`, and :meth:`require_scopes` are built
    on top of them.
    """

    # Identifier of the task this invocation belongs to.
    task_id: str
    # Caller identity model; its optional ``scopes`` attribute is consulted
    # by :meth:`require_scopes`.
    auth: AuthT

    @abstractmethod
    async def emit_event(self, event: AgentEvent) -> None:
        """Publish a structured event to subscribers (UI, logs, traces)."""

    @abstractmethod
    async def write_artifact(
        self, name: str, data: bytes, mime_type: str
    ) -> ArtifactRef:
        """Persist ``data`` as a named artifact and return a reference."""

    @abstractmethod
    async def check_cancelled(self) -> None:
        """Raise :class:`CancelledByCaller` if the caller cancelled."""

    @abstractmethod
    def secret(self, name: str) -> str:
        """Look up a runtime-injected secret by logical name."""

    @property
    @abstractmethod
    def workspace(self) -> WorkspaceClient:
        """Negotiation surface for workspace access.

        Raises if the agent's :attr:`A2AAgent.workspace_access` is disabled.
        """

    @property
    @abstractmethod
    def sandbox(self) -> SandboxClient:
        """Code-execution surface (microsandbox-backed by default).

        Raises :class:`SandboxUnavailable` if the runtime did not attach a
        sandbox client to this context (e.g. local dev with no host daemon).
        """

    @property
    @abstractmethod
    def discover(self) -> DiscoveryClient:
        """Registry-backed discovery: find other agents by tag/capability/skill."""

    async def call(
        self,
        target: str,
        skill: str,
        *,
        args: dict[str, Any] | None = None,
        grant: str | None = None,
        timeout: float | None = None,
    ) -> CallResult:
        """Invoke another agent's skill via the runtime's :class:`A2AClient`.

        ``target`` is whatever the underlying client expects — an HTTP URL
        for :class:`HttpA2AClient`, an agent name for in-process routing.
        Pair with :meth:`WorkspaceClient.delegate` to hand a scoped
        workspace grant to the callee.

        All keyword arguments are forwarded unchanged to ``client.call``.
        Whatever :meth:`_a2a_client` raises (e.g. when no client is
        attached) propagates to the caller.
        """
        client = self._a2a_client()
        return await client.call(target, skill, args=args, grant=grant, timeout=timeout)

    @abstractmethod
    def _a2a_client(self) -> A2AClient:
        """Return the runtime's outbound A2A client (or raise if absent)."""

    # --- concrete helpers built on emit_event ---

    async def emit_progress(self, message: str) -> None:
        """Emit a human-readable ``"progress"`` event."""
        await self.emit_event(AgentEvent(kind="progress", payload={"message": message}))

    async def emit_text_delta(self, text: str) -> None:
        """Emit a streamed token chunk (for LLM-style streaming output)."""
        await self.emit_event(AgentEvent(kind="text_delta", payload={"text": text}))

    async def emit_artifact(self, ref: ArtifactRef) -> None:
        """Notify subscribers that a new artifact is available.

        The event payload mirrors every field of ``ref``.
        """
        await self.emit_event(
            AgentEvent(
                kind="artifact",
                payload={
                    "name": ref.name,
                    "uri": ref.uri,
                    "mime_type": ref.mime_type,
                    "size_bytes": ref.size_bytes,
                },
            )
        )

    async def emit_error(self, message: str, *, code: str | None = None) -> None:
        """Emit a structured ``"error"`` event (does not raise).

        ``code`` is always present in the payload, as ``None`` when omitted.
        """
        await self.emit_event(
            AgentEvent(kind="error", payload={"message": message, "code": code})
        )

    def require_scopes(self, required: Sequence[str]) -> None:
        """Raise :class:`MissingScopes` if ``self.auth`` lacks any required scope.

        Auth models without a ``scopes`` attribute (e.g. :class:`NoAuth`), or
        whose ``scopes`` is ``None``/empty, are treated as having an empty
        scope set.

        A bare ``str`` — passed as ``required`` or stored as
        ``self.auth.scopes`` — is treated as a whitespace-separated scope
        list (the OAuth 2.0 ``scope`` convention). Without this, iterating a
        ``str`` would check individual characters, since ``str`` is itself
        a ``Sequence[str]``.

        Raises:
            MissingScopes: listing each missing scope once, in the order it
                first appears in ``required``.
        """
        wanted = self._normalize_scopes(required)
        if not wanted:
            return
        granted = set(self._normalize_scopes(getattr(self.auth, "scopes", None)))
        # dict.fromkeys de-duplicates while preserving first-seen order.
        missing = [s for s in dict.fromkeys(wanted) if s not in granted]
        if missing:
            raise MissingScopes(missing)

    @staticmethod
    def _normalize_scopes(scopes: str | Sequence[str] | None) -> list[str]:
        """Return ``scopes`` as a list, splitting a bare string on whitespace."""
        if not scopes:
            return []
        if isinstance(scopes, str):
            return scopes.split()
        return list(scopes)
class LocalRunContext(RunContext[AuthT]):
    """In-memory context for local dev and tests.

    Events are appended to :attr:`events` and artifact payloads are kept in
    :attr:`artifacts` keyed by name. Secrets come from a plain mapping
    copied at construction time. Cancellation is driven by an
    :class:`asyncio.Event` that :meth:`cancel` sets. Optional capabilities
    (workspace, sandbox, A2A client, discovery) raise when accessed without
    having been supplied to the constructor.
    """

    def __init__(
        self,
        *,
        auth: AuthT,
        task_id: str = "local-task",
        secrets: dict[str, str] | None = None,
        workspace: WorkspaceClient | None = None,
        sandbox: SandboxClient | None = None,
        a2a: A2AClient | None = None,
        discover: DiscoveryClient | None = None,
    ) -> None:
        """Build a context; every capability argument is optional."""
        self.task_id = task_id
        self.auth = auth
        # Copy so later mutation of the caller's mapping does not leak in.
        self._secrets: dict[str, str] = dict(secrets or {})
        self._workspace = workspace
        self._sandbox = sandbox
        self._a2a = a2a
        self._discover = discover
        # Only is_set()/set() are used on this event within this class.
        self._cancel = asyncio.Event()
        # Every event passed to emit_event, in emission order.
        self.events: list[AgentEvent] = []
        # Artifact payloads keyed by name; rewriting a name replaces it.
        self.artifacts: dict[str, bytes] = {}

    @property
    def workspace(self) -> WorkspaceClient:
        """Return the attached workspace client.

        Raises:
            PermissionError: if no workspace client was supplied.
        """
        if self._workspace is None:
            raise PermissionError(
                "no workspace bound to this context; agent did not declare "
                "workspace_access or runtime did not provision one"
            )
        return self._workspace

    @property
    def sandbox(self) -> SandboxClient:
        """Return the attached sandbox client.

        Raises:
            SandboxUnavailable: if no sandbox client was supplied.
        """
        if self._sandbox is None:
            raise SandboxUnavailable(
                "no sandbox client attached to this context; "
                "the runtime layer must provision one"
            )
        return self._sandbox

    @property
    def discover(self) -> DiscoveryClient:
        """Return the attached discovery client.

        Raises:
            PermissionError: if no discovery client was supplied.
        """
        if self._discover is None:
            raise PermissionError(
                "no discovery client attached; runtime must provision one"
            )
        return self._discover

    def _a2a_client(self) -> A2AClient:
        """Return the attached A2A client used by :meth:`call`.

        Raises:
            PermissionError: if no A2A client was supplied.
        """
        if self._a2a is None:
            raise PermissionError(
                "no A2A client attached; runtime must provision one before "
                "ctx.call(...) can be used"
            )
        return self._a2a

    async def emit_event(self, event: AgentEvent) -> None:
        """Record ``event`` by appending it to :attr:`events`."""
        self.events.append(event)

    async def write_artifact(
        self, name: str, data: bytes, mime_type: str
    ) -> ArtifactRef:
        """Store ``data`` in :attr:`artifacts` under ``name`` and return its ref.

        ``data`` is snapshotted with ``bytes(...)``, which returns an existing
        ``bytes`` object unchanged. A caller-owned ``bytearray`` mutated after
        the write therefore cannot alter the stored artifact, and
        ``size_bytes`` is the true byte length even for a ``memoryview``
        with a multi-byte item format. The returned URI has the form
        ``memory://<task_id>/<name>``.
        """
        payload = bytes(data)
        self.artifacts[name] = payload
        return ArtifactRef(
            name=name,
            uri=f"memory://{self.task_id}/{name}",
            mime_type=mime_type,
            size_bytes=len(payload),
        )

    async def check_cancelled(self) -> None:
        """Raise :class:`CancelledByCaller` (with the task id) once cancelled."""
        if self._cancel.is_set():
            raise CancelledByCaller(self.task_id)

    def cancel(self) -> None:
        """Mark this task cancelled; later :meth:`check_cancelled` calls raise."""
        self._cancel.set()

    def secret(self, name: str) -> str:
        """Return the secret registered under ``name``.

        Raises:
            KeyError: if ``name`` was not in the ``secrets`` mapping.
        """
        try:
            return self._secrets[name]
        except KeyError as exc:
            raise KeyError(f"unknown secret: {name!r}") from exc