Files
a2a/a2a_pack/grants.py

149 lines
4.7 KiB
Python

"""Signed grant tokens for cross-agent workspace handoff.
A grant is a small, self-contained, signed claim issued by one agent that
the platform (or the receiving agent) can verify without a registry round-trip.
Wire format::
"<base64url(json(payload))>.<base64url(hmac_sha256(secret, payload))>"
The payload describes *what* the callee is allowed to do, *whose* workspace
they can see, and *for how long*. The runtime on the receiving side
materializes a :class:`WorkspaceClient` scoped to that grant.
Auth model is intentionally simple for v1: a shared platform secret signs
every grant. Swap for asymmetric signatures (X.509 / JWKS) when crossing trust domains.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import os
import secrets
import time
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt
from .workspace import WorkspaceMode
# Default grant lifetime in seconds (five minutes); used by mint_grant when
# the caller does not pass ttl_seconds.
DEFAULT_TTL_SECONDS = 5 * 60
class GrantInvalid(PermissionError):
    """Signal that a grant token was rejected.

    :func:`verify_grant` raises this for every failure it detects: a
    malformed token, undecodable base64, a signature mismatch, an invalid
    payload, or an expired grant.
    """
class Grant(BaseModel):
    """Claims carried inside a signed grant token.

    One grant records that ``issuer`` handed ``audience`` access to the files
    of ``bucket`` selected by the allow/deny patterns, in the given ``mode``,
    until ``expires_at``.

    Instances are immutable (``frozen=True``) and unknown fields are rejected
    (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid", frozen=True)

    # Identifier of this grant.
    grant_id: str
    # Caller agent name or URL.
    issuer: str
    # Callee agent name or URL.
    audience: str
    # Workspace bucket the grant covers.
    bucket: str
    # Access mode; read-only unless stated otherwise.
    mode: WorkspaceMode = WorkspaceMode.READ_ONLY
    # Path patterns the callee may access; the default allows everything.
    allow_patterns: tuple[str, ...] = ("**",)
    # Path patterns excluded from the grant.
    deny_patterns: tuple[str, ...] = ()
    # If set, the callee writes only under this prefix.
    outputs_prefix: str | None = None
    # Expiry timestamp in epoch seconds; verify_grant skips the expiry check
    # when this is 0.
    expires_at: NonNegativeInt = 0
    # Issue timestamp in epoch seconds.
    issued_at: NonNegativeInt = 0
    # Random per-grant value (16 hex characters by default).
    nonce: str = Field(default_factory=lambda: secrets.token_hex(8))
def _b64encode(b: bytes) -> str:
    """Return *b* as URL-safe base64 text with the ``=`` padding removed."""
    text = base64.urlsafe_b64encode(b).decode("ascii")
    return text.rstrip("=")
def _b64decode(s: str) -> bytes:
    """Decode URL-safe base64 text whose ``=`` padding may have been stripped."""
    # Restore the padding so the length is a multiple of four again.
    shortfall = (4 - len(s) % 4) % 4
    padded = s.ljust(len(s) + shortfall, "=")
    return base64.urlsafe_b64decode(padded)
def _platform_secret() -> bytes:
    """Return the shared HMAC key used to sign and verify grants.

    The key is read from the ``A2A_PLATFORM_SECRET`` environment variable and
    UTF-8 encoded. When the variable is unset, the built-in development
    secret is returned for backward compatibility, and a
    :class:`RuntimeWarning` is emitted: that fallback value is public, so any
    deployment relying on it accepts forged grants.

    Returns:
        The signing key as bytes. A variable that is set but empty is used
        verbatim.
    """
    # Function-scope import so this block does not depend on a new
    # module-level import.
    import warnings

    configured = os.environ.get("A2A_PLATFORM_SECRET")
    if configured is None:
        warnings.warn(
            "A2A_PLATFORM_SECRET is not set; signing grants with the built-in "
            "development secret, which is publicly known and must not be "
            "used outside local development",
            RuntimeWarning,
            stacklevel=2,
        )
        configured = "dev-secret-rotate-me"
    return configured.encode("utf-8")
def mint_grant(
    *,
    issuer: str,
    audience: str,
    bucket: str,
    mode: WorkspaceMode = WorkspaceMode.READ_ONLY,
    allow_patterns: tuple[str, ...] = ("**",),
    deny_patterns: tuple[str, ...] = (),
    outputs_prefix: str | None = None,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
    secret: bytes | None = None,
) -> tuple[Grant, str]:
    """Create a fresh :class:`Grant` and sign it.

    The grant receives a random ``grant_id``, ``issued_at`` set to the current
    epoch second, and ``expires_at`` set ``ttl_seconds`` after that.

    Args:
        issuer: Caller agent name or URL.
        audience: Callee agent name or URL.
        bucket: Workspace bucket the grant covers.
        mode: Access mode recorded in the grant.
        allow_patterns: Patterns the callee may access.
        deny_patterns: Patterns excluded from the grant.
        outputs_prefix: Optional prefix the callee's writes are limited to.
        ttl_seconds: Lifetime of the grant in seconds.
        secret: Signing key; passed through to :func:`sign_grant`.

    Returns:
        A ``(grant, token)`` pair: the model and its signed wire form.
    """
    issued = int(time.time())
    claims: dict[str, Any] = {
        "grant_id": secrets.token_hex(8),
        "issuer": issuer,
        "audience": audience,
        "bucket": bucket,
        "mode": mode,
        # Accept any iterable of patterns, but store tuples.
        "allow_patterns": tuple(allow_patterns),
        "deny_patterns": tuple(deny_patterns),
        "outputs_prefix": outputs_prefix,
        "expires_at": issued + ttl_seconds,
        "issued_at": issued,
    }
    grant = Grant(**claims)
    token = sign_grant(grant, secret=secret)
    return grant, token
def sign_grant(grant: Grant, *, secret: bytes | None = None) -> str:
    """Serialize *grant* to JSON and return its signed token.

    The token is ``"<payload>.<signature>"``: both halves are unpadded
    URL-safe base64, and the signature is HMAC-SHA256 over the JSON bytes.

    Args:
        grant: The grant to serialize and sign.
        secret: Signing key. A missing or empty key falls back to the
            platform secret.

    Returns:
        The signed token string.
    """
    body = grant.model_dump_json(exclude_none=False).encode("utf-8")
    key = secret or _platform_secret()
    mac = hmac.new(key, body, hashlib.sha256)
    return ".".join((_b64encode(body), _b64encode(mac.digest())))
def verify_grant(token: str, *, secret: bytes | None = None) -> Grant:
    """Check a grant token and return the :class:`Grant` it carries.

    Checks run in this order: token shape, base64 decoding, HMAC-SHA256
    signature, payload validation, and expiry. Audience is not checked here;
    callers are expected to layer that on top.

    Args:
        token: Wire-format token, ``"<payload>.<signature>"``.
        secret: Verification key. A missing or empty key falls back to the
            platform secret.

    Returns:
        The validated grant.

    Raises:
        GrantInvalid: If any of the checks above fails.
    """
    if not token or "." not in token:
        raise GrantInvalid("malformed grant token")

    # Split on the last dot, so the signature is whatever follows it.
    encoded_payload, _, encoded_sig = token.rpartition(".")
    try:
        raw_payload = _b64decode(encoded_payload)
        presented_sig = _b64decode(encoded_sig)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise GrantInvalid(f"grant decode failed: {exc}") from exc

    # Authenticate the raw bytes before parsing them.
    signing_key = secret or _platform_secret()
    computed_sig = hmac.new(signing_key, raw_payload, hashlib.sha256).digest()
    if not hmac.compare_digest(computed_sig, presented_sig):
        raise GrantInvalid("grant signature mismatch")

    try:
        grant = Grant.model_validate(json.loads(raw_payload))
    except Exception as exc:  # noqa: BLE001
        raise GrantInvalid(f"grant payload invalid: {exc}") from exc

    # An expires_at of 0 is falsy, so the expiry check is skipped for it.
    deadline = grant.expires_at
    if deadline:
        if int(time.time()) > deadline:
            raise GrantInvalid(f"grant expired at {grant.expires_at}")
    return grant
# Names exported by ``from <this module> import *``; underscore-prefixed
# helpers are intentionally not listed.
__all__ = [
"Grant",
"GrantInvalid",
"mint_grant",
"sign_grant",
"verify_grant",
"DEFAULT_TTL_SECONDS",
]