411 lines
14 KiB
Python
411 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import inspect
|
|
import typing
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Awaitable, Callable, ClassVar, Generic, Sequence, TypeVar
|
|
|
|
from pydantic import BaseModel, TypeAdapter
|
|
|
|
from .auth import NoAuth
|
|
from .card import AgentCard
|
|
from .context import LocalRunContext, RunContext
|
|
from .runtime import (
|
|
AgentRuntime,
|
|
EgressPolicy,
|
|
Lifecycle,
|
|
Resources,
|
|
Sandbox,
|
|
SkillPolicy,
|
|
State,
|
|
)
|
|
from .workspace import WorkspaceAccess
|
|
|
|
# Pydantic model type of an agent's validated configuration
# (the type of ``A2AAgent.config``; see ``A2AAgent.config_model``).
ConfigT = TypeVar("ConfigT", bound=BaseModel)
# Pydantic model type of the auth object carried by ``RunContext[AuthT]``
# (see ``A2AAgent.auth_model``).
AuthT = TypeVar("AuthT", bound=BaseModel)

# Parameter names that @skill refuses for caller-supplied skill inputs.
_RESERVED_PARAM_NAMES = frozenset({"self", "ctx", "context"})
|
|
|
|
|
|
class _EmptyConfig(BaseModel):
    """Default config model when an agent declares no config.

    Declares no fields. :class:`A2AAgent` uses it as the default
    ``config_model``, so an agent without configuration can be constructed
    with no ``config`` argument.
    """
|
|
|
|
|
|
class SkillNotFound(KeyError):
    """Raised when invoke() is called with an unknown skill name.

    ``invoke_json()`` raises it too. The unknown skill name is the sole
    exception argument. Because it subclasses :class:`KeyError`, callers may
    treat it like a failed mapping lookup.
    """
|
|
|
|
|
|
class SkillInvocationError(RuntimeError):
    """Raised when a skill handler raises during invoke().

    The handler's original exception is chained as ``__cause__``. A
    :class:`SkillInputError` raised by the handler itself is re-raised
    unwrapped instead of being converted to this type.
    """
|
|
|
|
|
|
class SkillInputError(ValueError):
    """Raised when invoke() inputs fail validation against the skill schema.

    Covers unknown parameter names, missing required parameters, and values
    rejected by a parameter's type adapter. The underlying error is chained
    as ``__cause__``.
    """
|
|
|
|
|
|
@dataclass(frozen=True)
class ParamSpec:
    """Validation record for one caller-supplied skill parameter.

    :func:`skill` builds one of these for every handler parameter that
    follows ``self`` and the run context, in signature order.
    """

    # Parameter name exactly as written in the handler signature.
    name: str
    # pydantic adapter built from the parameter's type hint; used to
    # validate and coerce the incoming value.
    adapter: TypeAdapter[Any]
    # Whether the handler signature supplies a default for this parameter.
    has_default: bool
    # The signature default when ``has_default`` is true, otherwise None.
    default: Any = field(default=None)
|
|
|
|
|
|
@dataclass(frozen=True)
class SkillSpec:
    """Everything :func:`skill` records about one skill handler.

    Built once when the decorator runs and attached to the handler function
    as ``__a2a_skill__``. The agent metaclass collects these specs into the
    per-class skill registry.
    """

    # Published skill name (decorator ``name`` or the function's __name__).
    name: str
    # Free-text description supplied to the decorator.
    description: str
    # Tags supplied to the decorator, frozen into a tuple.
    tags: tuple[str, ...]
    # Scopes passed to ``ctx.require_scopes`` before the handler runs.
    scopes: tuple[str, ...]
    # ``stream`` flag supplied to the decorator.
    stream: bool
    # Execution policy (timeout, idempotency, retries, cost class).
    policy: SkillPolicy
    # JSON Schema object describing the caller-supplied parameters
    # (the run context is excluded).
    input_schema: dict[str, Any]
    # JSON Schema derived from the handler's return annotation.
    output_schema: dict[str, Any]
    # The decorated coroutine function, called as handler(agent, ctx, **inputs).
    handler: Callable[..., Awaitable[Any]]
    # One ParamSpec per caller-supplied parameter, in signature order.
    params: tuple[ParamSpec, ...] = field(default_factory=tuple)
    # Adapter for the return type; invoke_json() uses it to dump results to
    # JSON-compatible data. None means results are returned unchanged.
    output_adapter: TypeAdapter[Any] | None = field(default=None)
|
|
|
|
|
|
def _strip_annotated(tp: Any) -> Any:
    """Return ``T`` for ``Annotated[T, ...]``; return any other value unchanged."""
    if typing.get_origin(tp) is typing.Annotated:
        return typing.get_args(tp)[0]
    return tp


def skill(
    *,
    name: str | None = None,
    description: str = "",
    tags: Sequence[str] = (),
    scopes: Sequence[str] = (),
    stream: bool = False,
    timeout_seconds: float | None = None,
    idempotent: bool = False,
    max_retries: int = 0,
    cost_class: str | None = None,
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
    """Mark an :class:`A2AAgent` method as a discoverable skill.

    Conventions:

    - The handler MUST be ``async def``.
    - Its first parameter (after ``self``) MUST be a :class:`RunContext`
      that accepts a positional argument; the context is supplied by the
      runtime and is omitted from the published input schema.
    - Remaining parameters MUST be type-annotated and passable by keyword,
      because ``A2AAgent.invoke`` calls ``handler(self, ctx, **inputs)``.
      ``*args``, ``**kwargs`` and positional-only parameters are rejected.
    - ``typing.Annotated`` metadata on parameter and return annotations is
      kept. pydantic constraints such as ``Annotated[int, Field(ge=0)]``
      therefore appear in the published schema and are enforced when inputs
      are validated.

    The decorator attaches a :class:`SkillSpec` to the handler as
    ``__a2a_skill__`` and returns the handler itself. ``name`` defaults to
    the function's ``__name__``. ``timeout_seconds``, ``idempotent``,
    ``max_retries`` and ``cost_class`` are forwarded to :class:`SkillPolicy`.

    Raises:
        TypeError: at decoration time, if the handler breaks any of the
            conventions above.
    """

    def decorator(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        qualname = fn.__qualname__
        if not inspect.iscoroutinefunction(fn):
            raise TypeError(f"@skill requires an async function: {qualname}")

        sig = inspect.signature(fn)
        # include_extras=True keeps Annotated[...] metadata (e.g. pydantic
        # Field constraints) so TypeAdapter can publish and enforce it; the
        # default would silently strip it everywhere, including nested types.
        hints = typing.get_type_hints(fn, include_extras=True)
        params = list(sig.parameters.values())[1:]  # drop self
        if not params:
            raise TypeError(f"@skill {qualname}: missing RunContext parameter")

        ctx_param, *rest = params
        ctx_hint = _strip_annotated(hints.get(ctx_param.name))
        if ctx_hint is None or not _is_run_context(ctx_hint):
            raise TypeError(
                f"@skill {qualname}: first arg after self must be "
                f"annotated as RunContext (got {ctx_hint!r})"
            )
        # invoke() always passes the context positionally.
        if ctx_param.kind not in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            raise TypeError(
                f"@skill {qualname}: context parameter {ctx_param.name!r} "
                f"must accept a positional argument"
            )

        properties: dict[str, Any] = {}
        required: list[str] = []
        param_specs: list[ParamSpec] = []
        for p in rest:
            if p.kind is inspect.Parameter.VAR_POSITIONAL:
                raise TypeError(f"@skill {qualname}: *{p.name} is not allowed")
            if p.kind is inspect.Parameter.VAR_KEYWORD:
                raise TypeError(f"@skill {qualname}: **{p.name} is not allowed")
            if p.kind is inspect.Parameter.POSITIONAL_ONLY:
                # Skill inputs are always passed by keyword, so a
                # positional-only parameter could never be filled.
                raise TypeError(
                    f"@skill {qualname}: positional-only parameter "
                    f"{p.name!r} is not allowed"
                )
            if p.name in _RESERVED_PARAM_NAMES:
                raise TypeError(f"@skill {qualname}: reserved param name {p.name!r}")
            if p.name not in hints:
                raise TypeError(
                    f"@skill {qualname}: parameter {p.name!r} is "
                    f"missing a type annotation"
                )

            adapter: TypeAdapter[Any] = TypeAdapter(hints[p.name])
            properties[p.name] = adapter.json_schema()
            has_default = p.default is not inspect.Parameter.empty
            if not has_default:
                required.append(p.name)
            param_specs.append(
                ParamSpec(
                    name=p.name,
                    adapter=adapter,
                    has_default=has_default,
                    default=p.default if has_default else None,
                )
            )

        input_schema: dict[str, Any] = {
            "type": "object",
            "properties": properties,
            "required": required,
            # Mirrors A2AAgent._validate_inputs, which rejects unknown keys.
            "additionalProperties": False,
        }
        output_adapter: TypeAdapter[Any] = TypeAdapter(hints.get("return", Any))

        spec = SkillSpec(
            name=name or fn.__name__,
            description=description,
            tags=tuple(tags),
            scopes=tuple(scopes),
            stream=stream,
            policy=SkillPolicy(
                timeout_seconds=timeout_seconds,
                idempotent=idempotent,
                max_retries=max_retries,
                cost_class=cost_class,
            ),
            input_schema=input_schema,
            output_schema=output_adapter.json_schema(),
            handler=fn,
            params=tuple(param_specs),
            output_adapter=output_adapter,
        )
        fn.__a2a_skill__ = spec  # type: ignore[attr-defined]
        return fn

    return decorator
|
|
|
|
|
|
def _is_run_context(tp: Any) -> bool:
|
|
"""True if ``tp`` is :class:`RunContext` or a parametrization of it."""
|
|
origin = typing.get_origin(tp) or tp
|
|
try:
|
|
return isinstance(origin, type) and issubclass(origin, RunContext)
|
|
except TypeError:
|
|
return False
|
|
|
|
|
|
class _AgentMeta(type):
|
|
def __new__(mcs, cls_name, bases, namespace):
|
|
cls = super().__new__(mcs, cls_name, bases, namespace)
|
|
skills: dict[str, SkillSpec] = {}
|
|
for base in bases:
|
|
skills.update(getattr(base, "_skills", {}))
|
|
for attr in namespace.values():
|
|
spec = getattr(attr, "__a2a_skill__", None)
|
|
if spec is None:
|
|
continue
|
|
if spec.name in skills and skills[spec.name].handler is not spec.handler:
|
|
# Allow overrides from the same chain (parent → child) but
|
|
# forbid two distinct handlers in the same class.
|
|
if any(
|
|
spec.name in getattr(b, "_skills", {})
|
|
and getattr(b, "_skills")[spec.name].handler is spec.handler
|
|
for b in bases
|
|
):
|
|
pass # legitimate override
|
|
else:
|
|
raise TypeError(
|
|
f"duplicate skill name {spec.name!r} in {cls_name}"
|
|
)
|
|
skills[spec.name] = spec
|
|
cls._skills = skills # type: ignore[attr-defined]
|
|
return cls
|
|
|
|
|
|
class A2AAgent(Generic[ConfigT, AuthT], metaclass=_AgentMeta):
    """Base class for A2A agents.

    Subclasses declare:

    - ``name``, ``description`` (and optional ``version``),
    - optional ``config_model`` / ``auth_model`` (default to empty / NoAuth),
    - deployment metadata: ``required_secrets``, ``required_env``,
      ``capabilities``, ``input_modes``, ``output_modes``,
    - a runtime declaration: ``lifecycle``, ``state`` (with ``state_model``),
      ``resources``, ``concurrency``, ``egress``, ``tools_used``,
      ``workspace_access``,
    - one or more methods decorated with :func:`skill`.

    Defining a subclass raises :class:`TypeError` in any of these cases:

    - ``name`` is empty;
    - ``state`` is not ``State.NONE`` but ``state_model`` is unset;
    - ``lifecycle=EPHEMERAL`` is combined with ``state=SESSION``.
    """

    name: ClassVar[str] = ""
    description: ClassVar[str] = ""
    version: ClassVar[str] = "0.1.0"

    config_model: ClassVar[type[BaseModel]] = _EmptyConfig
    auth_model: ClassVar[type[BaseModel]] = NoAuth

    required_secrets: ClassVar[tuple[str, ...]] = ()
    required_env: ClassVar[tuple[str, ...]] = ()
    # NOTE(review): this dict is shared by every subclass that does not
    # assign its own; assign a new dict rather than mutating it in place.
    capabilities: ClassVar[dict[str, Any]] = {}
    input_modes: ClassVar[tuple[str, ...]] = ("application/json",)
    output_modes: ClassVar[tuple[str, ...]] = ("application/json",)

    # --- runtime / deployment declaration (read by the platform deployer) ---
    # Sandbox is always microsandbox; not exposed as a knob.
    lifecycle: ClassVar[Lifecycle] = Lifecycle.EPHEMERAL
    state: ClassVar[State] = State.NONE
    state_model: ClassVar[type[BaseModel] | None] = None
    resources: ClassVar[Resources] = Resources()
    concurrency: ClassVar[int] = 1
    egress: ClassVar[EgressPolicy] = EgressPolicy()
    tools_used: ClassVar[tuple[str, ...]] = ()
    workspace_access: ClassVar[WorkspaceAccess] = WorkspaceAccess.none()

    # Skill registry (skill name -> SkillSpec); rebuilt per class by _AgentMeta.
    _skills: ClassVar[dict[str, SkillSpec]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate the class-level declaration of every subclass."""
        super().__init_subclass__(**kwargs)
        if not cls.name:
            raise TypeError(f"{cls.__name__}.name must be set as a class attribute")
        if cls.state is not State.NONE and cls.state_model is None:
            raise TypeError(
                f"{cls.__name__} declares state={cls.state.value!r} but "
                f"state_model is not set"
            )
        if cls.lifecycle is Lifecycle.EPHEMERAL and cls.state is State.SESSION:
            raise TypeError(
                f"{cls.__name__}: lifecycle=ephemeral is incompatible with "
                f"state=session"
            )

    @classmethod
    def runtime(cls) -> AgentRuntime:
        """Aggregate the class-level runtime declaration.

        ``sandbox`` is always :attr:`Sandbox.MICROSANDBOX`; it is set here
        rather than on the class so developers cannot weaken isolation.
        """
        return AgentRuntime(
            lifecycle=cls.lifecycle,
            state=cls.state,
            sandbox=Sandbox.MICROSANDBOX,
            resources=cls.resources,
            concurrency=cls.concurrency,
            egress=cls.egress,
            tools_used=cls.tools_used,
        )

    def __init__(self, config: ConfigT | dict[str, Any] | None = None) -> None:
        """Validate ``config`` against :attr:`config_model` and store it.

        ``None`` means "no configuration" and is validated as ``{}``. Any
        other value, including a falsy one, is handed to ``model_validate``
        unchanged. Invalid input is therefore reported rather than silently
        replaced by an empty config.
        """
        raw = {} if config is None else config
        validated = type(self).config_model.model_validate(raw)
        self.config: ConfigT = typing.cast(ConfigT, validated)

    @property
    def skills(self) -> dict[str, SkillSpec]:
        """Skill registry of this agent's class (skill name -> SkillSpec).

        Returns the class-level dict itself, not a copy.
        """
        return type(self)._skills

    async def startup(self, ctx: RunContext[AuthT]) -> None:
        """Called once before the first invocation. Override to set up state."""

    async def shutdown(self, ctx: RunContext[AuthT]) -> None:
        """Called once before the agent process exits. Override to tear down."""

    async def health(self) -> bool:
        """Lightweight liveness check. Override to add real probes."""
        return True

    async def invoke(
        self,
        skill_name: str,
        ctx: RunContext[AuthT],
        /,
        **kwargs: Any,
    ) -> Any:
        """Invoke a skill with caller-supplied kwargs.

        Inputs are validated and coerced via the skill's pydantic schema.
        Required scopes are enforced with ``ctx.require_scopes`` before
        validation and before the handler runs. The raw handler return value
        is returned (Python-typed).

        Raises:
            SkillNotFound: if ``skill_name`` is not registered.
            SkillInputError: if the inputs fail validation, or if the handler
                itself raises one.
            SkillInvocationError: if the handler raises any other exception
                (chained as ``__cause__``).
        """
        spec = self.skills.get(skill_name)
        if spec is None:
            raise SkillNotFound(skill_name)

        ctx.require_scopes(spec.scopes)

        try:
            validated = self._validate_inputs(spec, kwargs)
        except Exception as exc:
            raise SkillInputError(
                f"invalid input for skill {spec.name!r}: {exc}"
            ) from exc

        try:
            return await spec.handler(self, ctx, **validated)
        except SkillInputError:
            raise
        except Exception as exc:
            raise SkillInvocationError(
                f"skill {spec.name!r} raised {type(exc).__name__}: {exc}"
            ) from exc

    async def invoke_json(
        self,
        skill_name: str,
        ctx: RunContext[AuthT],
        payload: dict[str, Any],
    ) -> Any:
        """Runtime-facing invoke: takes JSON-shaped payload, returns JSON-shaped result.

        ``payload`` must be a JSON object, i.e. a mapping with string keys.
        Its entries become the skill's keyword inputs for :meth:`invoke`.
        The result is dumped with the skill's output adapter in JSON mode.

        Raises:
            SkillNotFound: if ``skill_name`` is not registered.
            SkillInputError: if ``payload`` is not a string-keyed mapping, or
                fails validation against the skill's input schema.
            SkillInvocationError: if the handler raises.
        """
        from collections.abc import Mapping

        spec = self.skills.get(skill_name)
        if spec is None:
            raise SkillNotFound(skill_name)
        # A malformed payload (e.g. a JSON array or null) would otherwise fail
        # at ``**payload`` with a bare TypeError instead of SkillInputError.
        if not isinstance(payload, Mapping):
            raise SkillInputError(
                f"invalid input for skill {spec.name!r}: payload must be a "
                f"JSON object, got {type(payload).__name__}"
            )
        non_str_keys = [key for key in payload if not isinstance(key, str)]
        if non_str_keys:
            raise SkillInputError(
                f"invalid input for skill {spec.name!r}: payload keys must be "
                f"strings, got {non_str_keys!r}"
            )

        result = await self.invoke(skill_name, ctx, **payload)
        if spec.output_adapter is None:
            return result
        return spec.output_adapter.dump_python(result, mode="json")

    async def local_invoke(
        self,
        skill_name: str,
        /,
        *,
        auth: AuthT | None = None,
        secrets: dict[str, str] | None = None,
        task_id: str = "local-task",
        workspace: Any = None,  # WorkspaceClient or None
        **kwargs: Any,
    ) -> Any:
        """Convenience harness: build a :class:`LocalRunContext` and invoke.

        Useful in tests and notebooks. ``auth`` defaults to a default-constructed
        instance of the agent's ``auth_model`` (works for :class:`NoAuth`; pass
        an explicit instance for auth models with required fields). Pass
        ``workspace=`` to bind a :class:`WorkspaceClient`.
        """
        if auth is None:
            auth = typing.cast(AuthT, type(self).auth_model())
        ctx: LocalRunContext[AuthT] = LocalRunContext(
            auth=auth, secrets=secrets, task_id=task_id, workspace=workspace
        )
        return await self.invoke(skill_name, ctx, **kwargs)

    def card(self) -> AgentCard:
        """Build the :class:`AgentCard` describing this agent."""
        return AgentCard.from_agent(self)

    @staticmethod
    def _validate_inputs(
        spec: SkillSpec, kwargs: dict[str, Any]
    ) -> dict[str, Any]:
        """Validate caller inputs against ``spec.params``.

        Unknown names and missing required parameters raise ``ValueError``.
        Supplied values are validated/coerced by each parameter's adapter,
        whose errors propagate. Optional parameters that were not supplied
        are omitted so the handler's own default applies.
        """
        known = {p.name for p in spec.params}
        unknown = set(kwargs) - known
        if unknown:
            raise ValueError(f"unknown parameters: {sorted(unknown)}")

        validated: dict[str, Any] = {}
        for p in spec.params:
            if p.name in kwargs:
                validated[p.name] = p.adapter.validate_python(kwargs[p.name])
            elif not p.has_default:
                raise ValueError(f"missing required parameter: {p.name!r}")
            # else: omit so the handler's own default applies
        return validated
|