Files
a2a/a2a_pack/agent.py
2026-05-08 21:59:51 -03:00

411 lines
14 KiB
Python

from __future__ import annotations

import inspect
import typing
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, ClassVar, Generic, Sequence, TypeVar

from pydantic import BaseModel, TypeAdapter

from .auth import NoAuth
from .card import AgentCard
from .context import LocalRunContext, RunContext
from .runtime import (
    AgentRuntime,
    EgressPolicy,
    Lifecycle,
    Resources,
    Sandbox,
    SkillPolicy,
    State,
)
from .workspace import WorkspaceAccess
# Type parameter for an agent's validated configuration model.
ConfigT = TypeVar("ConfigT", bound=BaseModel)
# Type parameter for the auth model carried by a RunContext.
AuthT = TypeVar("AuthT", bound=BaseModel)
# Names that @skill refuses as skill input parameter names.
_RESERVED_PARAM_NAMES = frozenset(("self", "ctx", "context"))
class _EmptyConfig(BaseModel):
    """Field-less placeholder used as ``A2AAgent.config_model`` by default.

    Agents that do not declare their own config model validate their
    configuration against this empty model.
    """
class SkillNotFound(KeyError):
    """Signals that ``invoke()`` was asked for a skill name that is not registered.

    Being a :class:`KeyError`, it carries the unknown skill name as its
    single argument.
    """
class SkillInvocationError(RuntimeError):
    """Wraps an exception that escaped a skill handler during ``invoke()``.

    The handler's original exception is chained as ``__cause__``.
    """
class SkillInputError(ValueError):
    """Signals that ``invoke()`` inputs do not match the skill's parameters.

    Covers unknown names, missing required parameters and values that fail
    validation; the underlying error is chained as ``__cause__``.
    """
@dataclass(frozen=True)
class ParamSpec:
    """Immutable validation record for one skill input parameter.

    Built by :func:`skill` for every parameter after the RunContext.
    ``has_default`` decides whether an omitted input is an error; ``default``
    copies the handler's default (``None`` when there is none) and is not
    read by ``A2AAgent._validate_inputs``.
    """

    name: str  # parameter name exactly as written in the handler signature
    adapter: TypeAdapter[Any]  # validator built from the parameter annotation
    has_default: bool  # True when the handler signature supplies a default
    default: Any = field(default=None)  # handler's default, else None
@dataclass(frozen=True)
class SkillSpec:
    """Immutable description of one skill, produced when :func:`skill` runs.

    Holds the published metadata (name, description, tags, scopes, JSON
    schemas), the execution policy, the undecorated handler, and the
    per-parameter / return-value adapters used to validate and serialize.
    """

    name: str  # published skill name
    description: str  # human-readable description
    tags: tuple[str, ...]  # free-form tags
    scopes: tuple[str, ...]  # scopes passed to ctx.require_scopes()
    stream: bool  # stream flag as given to @skill
    policy: SkillPolicy  # timeout / idempotency / retry / cost settings
    input_schema: dict[str, Any]  # JSON schema of the keyword inputs
    output_schema: dict[str, Any]  # JSON schema of the return annotation
    handler: Callable[..., Awaitable[Any]]  # the async method itself
    # Validation records for each input, in signature order.
    params: tuple[ParamSpec, ...] = field(default_factory=tuple)
    # Serializer for the return value; None when not provided.
    output_adapter: TypeAdapter[Any] | None = field(default=None)
def skill(
    *,
    name: str | None = None,
    description: str = "",
    tags: Sequence[str] = (),
    scopes: Sequence[str] = (),
    stream: bool = False,
    timeout_seconds: float | None = None,
    idempotent: bool = False,
    max_retries: int = 0,
    cost_class: str | None = None,
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
    """Mark an :class:`A2AAgent` method as a discoverable skill.

    Conventions:
      - The handler MUST be ``async def``.
      - Its first parameter (after ``self``) MUST be a :class:`RunContext`
        and a regular positional parameter; the context is supplied by the
        runtime and is omitted from the published input schema.
      - Remaining parameters MUST be type-annotated and accept keyword
        arguments. ``*args``, ``**kwargs`` and positional-only parameters
        are rejected.

    Args:
        name: Published skill name; defaults to the function's ``__name__``.
        description: Human-readable description of the skill.
        tags: Free-form tags (a sequence of strings, not a bare ``str``).
        scopes: Scopes checked by ``A2AAgent.invoke`` via
            ``ctx.require_scopes`` (a sequence of strings, not a bare ``str``).
        stream: Stored on the resulting :class:`SkillSpec`.
        timeout_seconds, idempotent, max_retries, cost_class: Packed into the
            skill's :class:`SkillPolicy`.

    Returns:
        A decorator that attaches a :class:`SkillSpec` to the function as
        ``__a2a_skill__`` and returns the function itself.

    Raises:
        TypeError: ``tags``/``scopes`` is a bare string, or (at decoration
            time) the handler violates the conventions above, or two
            parameters contribute conflicting JSON-schema definitions.
    """
    for label, value in (("tags", tags), ("scopes", scopes)):
        # A bare str satisfies Sequence[str], but tuple("admin") would
        # silently split it into single-character entries.
        if isinstance(value, str):
            raise TypeError(
                f"@skill {label}= expects a sequence of strings, "
                f"got the string {value!r}"
            )

    def decorator(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        if not inspect.iscoroutinefunction(fn):
            raise TypeError(
                f"@skill requires an async function: {fn.__qualname__}"
            )
        sig = inspect.signature(fn)
        hints = typing.get_type_hints(fn)
        params = list(sig.parameters.values())[1:]  # drop self
        if not params:
            raise TypeError(
                f"@skill {fn.__qualname__}: missing RunContext parameter"
            )
        ctx_param, *rest = params
        ctx_hint = hints.get(ctx_param.name)
        if ctx_hint is None or not _is_run_context(ctx_hint):
            raise TypeError(
                f"@skill {fn.__qualname__}: first arg after self must be "
                f"annotated as RunContext (got {ctx_hint!r})"
            )
        # A2AAgent.invoke calls handler(self, ctx, **inputs), so the context
        # must be able to receive a positional argument.
        if ctx_param.kind not in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            raise TypeError(
                f"@skill {fn.__qualname__}: RunContext parameter "
                f"{ctx_param.name!r} must be a regular positional parameter"
            )

        properties: dict[str, Any] = {}
        required: list[str] = []
        param_specs: list[ParamSpec] = []
        # Shared schema definitions hoisted out of every parameter schema.
        defs: dict[str, Any] = {}
        for p in rest:
            if p.kind is inspect.Parameter.VAR_POSITIONAL:
                raise TypeError(
                    f"@skill {fn.__qualname__}: *{p.name} is not allowed"
                )
            if p.kind is inspect.Parameter.VAR_KEYWORD:
                raise TypeError(
                    f"@skill {fn.__qualname__}: **{p.name} is not allowed"
                )
            if p.kind is inspect.Parameter.POSITIONAL_ONLY:
                # Inputs are always passed by keyword, which a
                # positional-only parameter cannot accept.
                raise TypeError(
                    f"@skill {fn.__qualname__}: positional-only parameter "
                    f"{p.name!r} is not allowed"
                )
            if p.name in _RESERVED_PARAM_NAMES:
                raise TypeError(
                    f"@skill {fn.__qualname__}: reserved param name {p.name!r}"
                )
            if p.name not in hints:
                raise TypeError(
                    f"@skill {fn.__qualname__}: parameter {p.name!r} is "
                    f"missing a type annotation"
                )
            adapter: TypeAdapter[Any] = TypeAdapter(hints[p.name])

            # Pydantic places nested-model definitions under a "$defs" key and
            # refers to them as "#/$defs/<Name>". Such pointers resolve
            # against the document root, so the definitions must live on
            # input_schema itself rather than inside the property schema.
            prop_schema = dict(adapter.json_schema())
            for def_name, def_schema in prop_schema.pop("$defs", {}).items():
                if defs.setdefault(def_name, def_schema) != def_schema:
                    raise TypeError(
                        f"@skill {fn.__qualname__}: parameter {p.name!r} "
                        f"defines JSON schema {def_name!r} differently from "
                        f"another parameter"
                    )
            properties[p.name] = prop_schema

            has_default = p.default is not inspect.Parameter.empty
            if not has_default:
                required.append(p.name)
            param_specs.append(
                ParamSpec(
                    name=p.name,
                    adapter=adapter,
                    has_default=has_default,
                    default=p.default if has_default else None,
                )
            )

        input_schema: dict[str, Any] = {
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": False,
        }
        if defs:
            input_schema["$defs"] = defs

        return_tp = hints.get("return", Any)
        output_adapter: TypeAdapter[Any] = TypeAdapter(return_tp)
        spec = SkillSpec(
            name=name or fn.__name__,
            description=description,
            tags=tuple(tags),
            scopes=tuple(scopes),
            stream=stream,
            policy=SkillPolicy(
                timeout_seconds=timeout_seconds,
                idempotent=idempotent,
                max_retries=max_retries,
                cost_class=cost_class,
            ),
            input_schema=input_schema,
            output_schema=output_adapter.json_schema(),
            handler=fn,
            params=tuple(param_specs),
            output_adapter=output_adapter,
        )
        fn.__a2a_skill__ = spec  # type: ignore[attr-defined]
        return fn

    return decorator
def _is_run_context(tp: Any) -> bool:
    """Report whether the annotation ``tp`` denotes :class:`RunContext`.

    Accepts ``RunContext`` itself, any subclass of it, and parametrized forms
    such as ``RunContext[SomeAuth]`` (unwrapped via ``typing.get_origin``).
    Anything that is not a class after unwrapping yields ``False``.
    """
    origin = typing.get_origin(tp)
    if not origin:
        # Not a parametrized generic: inspect the annotation as given.
        origin = tp
    try:
        if not isinstance(origin, type):
            return False
        return issubclass(origin, RunContext)
    except TypeError:
        # issubclass() can reject exotic class-like objects.
        return False
class _AgentMeta(type):
def __new__(mcs, cls_name, bases, namespace):
cls = super().__new__(mcs, cls_name, bases, namespace)
skills: dict[str, SkillSpec] = {}
for base in bases:
skills.update(getattr(base, "_skills", {}))
for attr in namespace.values():
spec = getattr(attr, "__a2a_skill__", None)
if spec is None:
continue
if spec.name in skills and skills[spec.name].handler is not spec.handler:
# Allow overrides from the same chain (parent → child) but
# forbid two distinct handlers in the same class.
if any(
spec.name in getattr(b, "_skills", {})
and getattr(b, "_skills")[spec.name].handler is spec.handler
for b in bases
):
pass # legitimate override
else:
raise TypeError(
f"duplicate skill name {spec.name!r} in {cls_name}"
)
skills[spec.name] = spec
cls._skills = skills # type: ignore[attr-defined]
return cls
class A2AAgent(Generic[ConfigT, AuthT], metaclass=_AgentMeta):
    """Base class for A2A agents.

    Subclasses declare:
      - ``name``, ``description`` (and optional ``version``),
      - optional ``config_model`` / ``auth_model`` (default to empty / NoAuth),
      - deployment metadata: ``required_secrets``, ``required_env``,
        ``capabilities``, ``input_modes``, ``output_modes``,
      - runtime declaration: ``lifecycle``, ``state`` / ``state_model``,
        ``resources``, ``concurrency``, ``egress``, ``tools_used``,
        ``workspace_access``,
      - one or more methods decorated with :func:`skill`.
    """

    name: ClassVar[str] = ""
    description: ClassVar[str] = ""
    version: ClassVar[str] = "0.1.0"
    config_model: ClassVar[type[BaseModel]] = _EmptyConfig
    auth_model: ClassVar[type[BaseModel]] = NoAuth
    required_secrets: ClassVar[tuple[str, ...]] = ()
    required_env: ClassVar[tuple[str, ...]] = ()
    # NOTE: one dict object shared by every subclass that does not reassign
    # it; subclasses should assign a fresh dict rather than mutate this one.
    capabilities: ClassVar[dict[str, Any]] = {}
    input_modes: ClassVar[tuple[str, ...]] = ("application/json",)
    output_modes: ClassVar[tuple[str, ...]] = ("application/json",)

    # --- runtime / deployment declaration (read by the platform deployer) ---
    # Sandbox is always microsandbox; not exposed as a knob.
    lifecycle: ClassVar[Lifecycle] = Lifecycle.EPHEMERAL
    state: ClassVar[State] = State.NONE
    state_model: ClassVar[type[BaseModel] | None] = None
    resources: ClassVar[Resources] = Resources()
    concurrency: ClassVar[int] = 1
    egress: ClassVar[EgressPolicy] = EgressPolicy()
    tools_used: ClassVar[tuple[str, ...]] = ()
    workspace_access: ClassVar[WorkspaceAccess] = WorkspaceAccess.none()

    # Skill registry; _AgentMeta replaces it with a fresh dict on every class.
    _skills: ClassVar[dict[str, SkillSpec]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate a subclass's class-level declaration when it is defined.

        Raises:
            TypeError: ``name`` is empty; ``state`` is not ``State.NONE`` but
                ``state_model`` is unset; or ``lifecycle`` is
                ``Lifecycle.EPHEMERAL`` while ``state`` is ``State.SESSION``.
        """
        super().__init_subclass__(**kwargs)
        if not cls.name:
            raise TypeError(
                f"{cls.__name__}.name must be set as a class attribute"
            )
        if cls.state is not State.NONE and cls.state_model is None:
            raise TypeError(
                f"{cls.__name__} declares state={cls.state.value!r} but "
                f"state_model is not set"
            )
        if cls.lifecycle is Lifecycle.EPHEMERAL and cls.state is State.SESSION:
            raise TypeError(
                f"{cls.__name__}: lifecycle=ephemeral is incompatible with "
                f"state=session"
            )

    @classmethod
    def runtime(cls) -> AgentRuntime:
        """Aggregate the class-level runtime declaration.

        ``sandbox`` is always :attr:`Sandbox.MICROSANDBOX`; it is set here
        rather than on the class so developers cannot weaken isolation.
        """
        return AgentRuntime(
            lifecycle=cls.lifecycle,
            state=cls.state,
            sandbox=Sandbox.MICROSANDBOX,
            resources=cls.resources,
            concurrency=cls.concurrency,
            egress=cls.egress,
            tools_used=cls.tools_used,
        )

    def __init__(self, config: ConfigT | dict[str, Any] | None = None) -> None:
        """Validate ``config`` against ``config_model`` and keep it as ``self.config``.

        ``None`` means "no configuration" and is validated as an empty dict.
        Validation errors from ``model_validate`` propagate unchanged.
        """
        raw = {} if config is None else config
        validated = type(self).config_model.model_validate(raw)
        self.config: ConfigT = typing.cast(ConfigT, validated)

    @property
    def skills(self) -> dict[str, SkillSpec]:
        """The class-level skill registry (name -> :class:`SkillSpec`)."""
        return type(self)._skills

    async def startup(self, ctx: RunContext[AuthT]) -> None:
        """Called once before the first invocation. Override to set up state."""

    async def shutdown(self, ctx: RunContext[AuthT]) -> None:
        """Called once before the agent process exits. Override to tear down."""

    async def health(self) -> bool:
        """Lightweight liveness check. Override to add real probes."""
        return True

    async def invoke(
        self,
        skill_name: str,
        ctx: RunContext[AuthT],
        /,
        **kwargs: Any,
    ) -> Any:
        """Invoke a skill with caller-supplied kwargs.

        Required scopes are checked with ``ctx.require_scopes`` first; the
        inputs are then validated and coerced with the skill's parameter
        adapters. The raw handler return value is returned (Python-typed).

        Raises:
            SkillNotFound: ``skill_name`` is not registered on this agent.
            SkillInputError: inputs are unknown, missing or invalid.
            SkillInvocationError: the handler raised; the original exception
                is chained. A ``SkillInputError`` raised by the handler is
                re-raised as-is.
            Anything raised by ``ctx.require_scopes`` propagates unwrapped.
        """
        spec = self.skills.get(skill_name)
        if spec is None:
            raise SkillNotFound(skill_name)
        ctx.require_scopes(spec.scopes)
        try:
            validated = self._validate_inputs(spec, kwargs)
        except Exception as exc:
            raise SkillInputError(
                f"invalid input for skill {spec.name!r}: {exc}"
            ) from exc
        try:
            return await spec.handler(self, ctx, **validated)
        except SkillInputError:
            raise
        except Exception as exc:
            raise SkillInvocationError(
                f"skill {spec.name!r} raised {type(exc).__name__}: {exc}"
            ) from exc

    async def invoke_json(
        self,
        skill_name: str,
        ctx: RunContext[AuthT],
        payload: dict[str, Any],
    ) -> Any:
        """Runtime-facing invoke: takes JSON-shaped payload, returns JSON-shaped result.

        ``payload`` must be a mapping with string keys (a decoded JSON
        object). Any other payload raises :class:`SkillInputError` instead of
        leaking a ``TypeError`` from ``**`` unpacking. The result is dumped
        with the skill's output adapter in JSON mode, or returned unchanged
        when the skill has no output adapter.

        Raises:
            SkillNotFound: ``skill_name`` is not registered on this agent.
            SkillInputError: ``payload`` is not a string-keyed mapping, or
                fails validation in :meth:`invoke`.
            SkillInvocationError: the handler raised (see :meth:`invoke`).
        """
        spec = self.skills.get(skill_name)
        if spec is None:
            raise SkillNotFound(skill_name)
        if not isinstance(payload, Mapping):
            raise SkillInputError(
                f"invalid input for skill {spec.name!r}: payload must be a "
                f"JSON object, got {type(payload).__name__}"
            )
        if not all(isinstance(key, str) for key in payload):
            raise SkillInputError(
                f"invalid input for skill {spec.name!r}: payload keys must "
                f"be strings"
            )
        result = await self.invoke(skill_name, ctx, **payload)
        if spec.output_adapter is None:
            return result
        return spec.output_adapter.dump_python(result, mode="json")

    async def local_invoke(
        self,
        skill_name: str,
        /,
        *,
        auth: AuthT | None = None,
        secrets: dict[str, str] | None = None,
        task_id: str = "local-task",
        workspace: Any = None,  # WorkspaceClient or None
        **kwargs: Any,
    ) -> Any:
        """Convenience harness: build a :class:`LocalRunContext` and invoke.

        Useful in tests and notebooks. ``auth`` defaults to a default-constructed
        instance of the agent's ``auth_model`` (works for :class:`NoAuth`; pass
        an explicit instance for auth models with required fields). Pass
        ``workspace=`` to bind a :class:`WorkspaceClient`.

        Because ``auth``, ``secrets``, ``task_id`` and ``workspace`` are
        keyword-only parameters here, skill inputs with those names cannot be
        passed through this helper; use :meth:`invoke` with an explicit
        context instead.
        """
        if auth is None:
            auth = typing.cast(AuthT, type(self).auth_model())
        ctx: LocalRunContext[AuthT] = LocalRunContext(
            auth=auth, secrets=secrets, task_id=task_id, workspace=workspace
        )
        return await self.invoke(skill_name, ctx, **kwargs)

    def card(self) -> AgentCard:
        """Build this agent's :class:`AgentCard` via ``AgentCard.from_agent``."""
        return AgentCard.from_agent(self)

    @staticmethod
    def _validate_inputs(
        spec: SkillSpec, kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Validate ``kwargs`` against ``spec.params`` and return coerced values.

        Inputs that are omitted but have a handler default are left out of
        the result so the handler's own default applies.

        Raises:
            ValueError: an unknown or missing parameter. Adapter validation
                errors propagate as raised by ``validate_python``.
        """
        known = {p.name for p in spec.params}
        unknown = set(kwargs) - known
        if unknown:
            raise ValueError(f"unknown parameters: {sorted(unknown)}")
        validated: dict[str, Any] = {}
        for p in spec.params:
            if p.name in kwargs:
                validated[p.name] = p.adapter.validate_python(kwargs[p.name])
            elif not p.has_default:
                raise ValueError(f"missing required parameter: {p.name!r}")
            # else: omit so the handler's own default applies
        return validated