initial a2a-pack
This commit is contained in:
410
a2a_pack/agent.py
Normal file
410
a2a_pack/agent.py
Normal file
@@ -0,0 +1,410 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Awaitable, Callable, ClassVar, Generic, Sequence, TypeVar
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from .auth import NoAuth
|
||||
from .card import AgentCard
|
||||
from .context import LocalRunContext, RunContext
|
||||
from .runtime import (
|
||||
AgentRuntime,
|
||||
EgressPolicy,
|
||||
Lifecycle,
|
||||
Resources,
|
||||
Sandbox,
|
||||
SkillPolicy,
|
||||
State,
|
||||
)
|
||||
from .workspace import WorkspaceAccess
|
||||
|
||||
ConfigT = TypeVar("ConfigT", bound=BaseModel)
|
||||
AuthT = TypeVar("AuthT", bound=BaseModel)
|
||||
|
||||
|
||||
_RESERVED_PARAM_NAMES = frozenset({"self", "ctx", "context"})
|
||||
|
||||
|
||||
class _EmptyConfig(BaseModel):
|
||||
"""Default config model when an agent declares no config."""
|
||||
|
||||
|
||||
class SkillNotFound(KeyError):
|
||||
"""Raised when invoke() is called with an unknown skill name."""
|
||||
|
||||
|
||||
class SkillInvocationError(RuntimeError):
|
||||
"""Raised when a skill handler raises during invoke()."""
|
||||
|
||||
|
||||
class SkillInputError(ValueError):
|
||||
"""Raised when invoke() inputs fail validation against the skill schema."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ParamSpec:
|
||||
"""Validation metadata for a single skill parameter."""
|
||||
|
||||
name: str
|
||||
adapter: TypeAdapter[Any]
|
||||
has_default: bool
|
||||
default: Any = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillSpec:
|
||||
"""Static metadata about a single skill, captured at decoration time."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
tags: tuple[str, ...]
|
||||
scopes: tuple[str, ...]
|
||||
stream: bool
|
||||
policy: SkillPolicy
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
handler: Callable[..., Awaitable[Any]]
|
||||
params: tuple[ParamSpec, ...] = field(default_factory=tuple)
|
||||
output_adapter: TypeAdapter[Any] | None = None
|
||||
|
||||
|
||||
def skill(
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str = "",
|
||||
tags: Sequence[str] = (),
|
||||
scopes: Sequence[str] = (),
|
||||
stream: bool = False,
|
||||
timeout_seconds: float | None = None,
|
||||
idempotent: bool = False,
|
||||
max_retries: int = 0,
|
||||
cost_class: str | None = None,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
|
||||
"""Mark an :class:`A2AAgent` method as a discoverable skill.
|
||||
|
||||
Conventions:
|
||||
|
||||
- The handler MUST be ``async def``.
|
||||
- Its first parameter (after ``self``) MUST be a :class:`RunContext`;
|
||||
the context is supplied by the runtime and is omitted from the
|
||||
published input schema.
|
||||
- Remaining parameters MUST be type-annotated. ``*args`` and ``**kwargs``
|
||||
are rejected.
|
||||
"""
|
||||
|
||||
def decorator(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
|
||||
if not inspect.iscoroutinefunction(fn):
|
||||
raise TypeError(
|
||||
f"@skill requires an async function: {fn.__qualname__}"
|
||||
)
|
||||
|
||||
sig = inspect.signature(fn)
|
||||
hints = typing.get_type_hints(fn)
|
||||
params = list(sig.parameters.values())[1:] # drop self
|
||||
if not params:
|
||||
raise TypeError(
|
||||
f"@skill {fn.__qualname__}: missing RunContext parameter"
|
||||
)
|
||||
|
||||
ctx_param, *rest = params
|
||||
ctx_hint = hints.get(ctx_param.name)
|
||||
if ctx_hint is None or not _is_run_context(ctx_hint):
|
||||
raise TypeError(
|
||||
f"@skill {fn.__qualname__}: first arg after self must be "
|
||||
f"annotated as RunContext (got {ctx_hint!r})"
|
||||
)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
param_specs: list[ParamSpec] = []
|
||||
for p in rest:
|
||||
if p.kind is inspect.Parameter.VAR_POSITIONAL:
|
||||
raise TypeError(
|
||||
f"@skill {fn.__qualname__}: *{p.name} is not allowed"
|
||||
)
|
||||
if p.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
raise TypeError(
|
||||
f"@skill {fn.__qualname__}: **{p.name} is not allowed"
|
||||
)
|
||||
if p.name in _RESERVED_PARAM_NAMES:
|
||||
raise TypeError(
|
||||
f"@skill {fn.__qualname__}: reserved param name {p.name!r}"
|
||||
)
|
||||
if p.name not in hints:
|
||||
raise TypeError(
|
||||
f"@skill {fn.__qualname__}: parameter {p.name!r} is "
|
||||
f"missing a type annotation"
|
||||
)
|
||||
tp = hints[p.name]
|
||||
adapter: TypeAdapter[Any] = TypeAdapter(tp)
|
||||
properties[p.name] = adapter.json_schema()
|
||||
has_default = p.default is not inspect.Parameter.empty
|
||||
if not has_default:
|
||||
required.append(p.name)
|
||||
param_specs.append(
|
||||
ParamSpec(
|
||||
name=p.name,
|
||||
adapter=adapter,
|
||||
has_default=has_default,
|
||||
default=None if not has_default else p.default,
|
||||
)
|
||||
)
|
||||
|
||||
input_schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
return_tp = hints.get("return", Any)
|
||||
output_adapter: TypeAdapter[Any] = TypeAdapter(return_tp)
|
||||
|
||||
spec = SkillSpec(
|
||||
name=name or fn.__name__,
|
||||
description=description,
|
||||
tags=tuple(tags),
|
||||
scopes=tuple(scopes),
|
||||
stream=stream,
|
||||
policy=SkillPolicy(
|
||||
timeout_seconds=timeout_seconds,
|
||||
idempotent=idempotent,
|
||||
max_retries=max_retries,
|
||||
cost_class=cost_class,
|
||||
),
|
||||
input_schema=input_schema,
|
||||
output_schema=output_adapter.json_schema(),
|
||||
handler=fn,
|
||||
params=tuple(param_specs),
|
||||
output_adapter=output_adapter,
|
||||
)
|
||||
fn.__a2a_skill__ = spec # type: ignore[attr-defined]
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_run_context(tp: Any) -> bool:
|
||||
"""True if ``tp`` is :class:`RunContext` or a parametrization of it."""
|
||||
origin = typing.get_origin(tp) or tp
|
||||
try:
|
||||
return isinstance(origin, type) and issubclass(origin, RunContext)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
class _AgentMeta(type):
|
||||
def __new__(mcs, cls_name, bases, namespace):
|
||||
cls = super().__new__(mcs, cls_name, bases, namespace)
|
||||
skills: dict[str, SkillSpec] = {}
|
||||
for base in bases:
|
||||
skills.update(getattr(base, "_skills", {}))
|
||||
for attr in namespace.values():
|
||||
spec = getattr(attr, "__a2a_skill__", None)
|
||||
if spec is None:
|
||||
continue
|
||||
if spec.name in skills and skills[spec.name].handler is not spec.handler:
|
||||
# Allow overrides from the same chain (parent → child) but
|
||||
# forbid two distinct handlers in the same class.
|
||||
if any(
|
||||
spec.name in getattr(b, "_skills", {})
|
||||
and getattr(b, "_skills")[spec.name].handler is spec.handler
|
||||
for b in bases
|
||||
):
|
||||
pass # legitimate override
|
||||
else:
|
||||
raise TypeError(
|
||||
f"duplicate skill name {spec.name!r} in {cls_name}"
|
||||
)
|
||||
skills[spec.name] = spec
|
||||
cls._skills = skills # type: ignore[attr-defined]
|
||||
return cls
|
||||
|
||||
|
||||
class A2AAgent(Generic[ConfigT, AuthT], metaclass=_AgentMeta):
|
||||
"""Base class for A2A agents.
|
||||
|
||||
Subclasses declare:
|
||||
|
||||
- ``name``, ``description`` (and optional ``version``),
|
||||
- optional ``config_model`` / ``auth_model`` (default to empty / NoAuth),
|
||||
- deployment metadata: ``required_secrets``, ``required_env``,
|
||||
``capabilities``, ``input_modes``, ``output_modes``,
|
||||
- one or more methods decorated with :func:`skill`.
|
||||
"""
|
||||
|
||||
name: ClassVar[str] = ""
|
||||
description: ClassVar[str] = ""
|
||||
version: ClassVar[str] = "0.1.0"
|
||||
|
||||
config_model: ClassVar[type[BaseModel]] = _EmptyConfig
|
||||
auth_model: ClassVar[type[BaseModel]] = NoAuth
|
||||
|
||||
required_secrets: ClassVar[tuple[str, ...]] = ()
|
||||
required_env: ClassVar[tuple[str, ...]] = ()
|
||||
capabilities: ClassVar[dict[str, Any]] = {}
|
||||
input_modes: ClassVar[tuple[str, ...]] = ("application/json",)
|
||||
output_modes: ClassVar[tuple[str, ...]] = ("application/json",)
|
||||
|
||||
# --- runtime / deployment declaration (read by the platform deployer) ---
|
||||
# Sandbox is always microsandbox; not exposed as a knob.
|
||||
lifecycle: ClassVar[Lifecycle] = Lifecycle.EPHEMERAL
|
||||
state: ClassVar[State] = State.NONE
|
||||
state_model: ClassVar[type[BaseModel] | None] = None
|
||||
resources: ClassVar[Resources] = Resources()
|
||||
concurrency: ClassVar[int] = 1
|
||||
egress: ClassVar[EgressPolicy] = EgressPolicy()
|
||||
tools_used: ClassVar[tuple[str, ...]] = ()
|
||||
workspace_access: ClassVar[WorkspaceAccess] = WorkspaceAccess.none()
|
||||
|
||||
_skills: ClassVar[dict[str, SkillSpec]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not cls.name:
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.name must be set as a class attribute"
|
||||
)
|
||||
if cls.state is not State.NONE and cls.state_model is None:
|
||||
raise TypeError(
|
||||
f"{cls.__name__} declares state={cls.state.value!r} but "
|
||||
f"state_model is not set"
|
||||
)
|
||||
if cls.lifecycle is Lifecycle.EPHEMERAL and cls.state is State.SESSION:
|
||||
raise TypeError(
|
||||
f"{cls.__name__}: lifecycle=ephemeral is incompatible with "
|
||||
f"state=session"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def runtime(cls) -> AgentRuntime:
|
||||
"""Aggregate the class-level runtime declaration.
|
||||
|
||||
``sandbox`` is always :attr:`Sandbox.MICROSANDBOX`; it is set here
|
||||
rather than on the class so developers cannot weaken isolation.
|
||||
"""
|
||||
return AgentRuntime(
|
||||
lifecycle=cls.lifecycle,
|
||||
state=cls.state,
|
||||
sandbox=Sandbox.MICROSANDBOX,
|
||||
resources=cls.resources,
|
||||
concurrency=cls.concurrency,
|
||||
egress=cls.egress,
|
||||
tools_used=cls.tools_used,
|
||||
)
|
||||
|
||||
def __init__(self, config: ConfigT | dict[str, Any] | None = None) -> None:
|
||||
validated = type(self).config_model.model_validate(config or {})
|
||||
self.config: ConfigT = typing.cast(ConfigT, validated)
|
||||
|
||||
@property
|
||||
def skills(self) -> dict[str, SkillSpec]:
|
||||
return type(self)._skills
|
||||
|
||||
async def startup(self, ctx: RunContext[AuthT]) -> None:
|
||||
"""Called once before the first invocation. Override to set up state."""
|
||||
|
||||
async def shutdown(self, ctx: RunContext[AuthT]) -> None:
|
||||
"""Called once before the agent process exits. Override to tear down."""
|
||||
|
||||
async def health(self) -> bool:
|
||||
"""Lightweight liveness check. Override to add real probes."""
|
||||
return True
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
skill_name: str,
|
||||
ctx: RunContext[AuthT],
|
||||
/,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Invoke a skill with caller-supplied kwargs.
|
||||
|
||||
Inputs are validated and coerced via the skill's pydantic schema.
|
||||
Required scopes are enforced against ``ctx.auth`` before the handler
|
||||
runs. The raw handler return value is returned (Python-typed).
|
||||
"""
|
||||
spec = self.skills.get(skill_name)
|
||||
if spec is None:
|
||||
raise SkillNotFound(skill_name)
|
||||
|
||||
ctx.require_scopes(spec.scopes)
|
||||
|
||||
try:
|
||||
validated = self._validate_inputs(spec, kwargs)
|
||||
except Exception as exc:
|
||||
raise SkillInputError(
|
||||
f"invalid input for skill {spec.name!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
return await spec.handler(self, ctx, **validated)
|
||||
except SkillInputError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise SkillInvocationError(
|
||||
f"skill {spec.name!r} raised {type(exc).__name__}: {exc}"
|
||||
) from exc
|
||||
|
||||
async def invoke_json(
|
||||
self,
|
||||
skill_name: str,
|
||||
ctx: RunContext[AuthT],
|
||||
payload: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Runtime-facing invoke: takes JSON-shaped payload, returns JSON-shaped result."""
|
||||
spec = self.skills.get(skill_name)
|
||||
if spec is None:
|
||||
raise SkillNotFound(skill_name)
|
||||
result = await self.invoke(skill_name, ctx, **payload)
|
||||
if spec.output_adapter is None:
|
||||
return result
|
||||
return spec.output_adapter.dump_python(result, mode="json")
|
||||
|
||||
async def local_invoke(
|
||||
self,
|
||||
skill_name: str,
|
||||
/,
|
||||
*,
|
||||
auth: AuthT | None = None,
|
||||
secrets: dict[str, str] | None = None,
|
||||
task_id: str = "local-task",
|
||||
workspace: Any = None, # WorkspaceClient or None
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Convenience harness: build a :class:`LocalRunContext` and invoke.
|
||||
|
||||
Useful in tests and notebooks. ``auth`` defaults to a default-constructed
|
||||
instance of the agent's ``auth_model`` (works for :class:`NoAuth`; pass
|
||||
an explicit instance for auth models with required fields). Pass
|
||||
``workspace=`` to bind a :class:`WorkspaceClient`.
|
||||
"""
|
||||
if auth is None:
|
||||
auth = typing.cast(AuthT, type(self).auth_model())
|
||||
ctx: LocalRunContext[AuthT] = LocalRunContext(
|
||||
auth=auth, secrets=secrets, task_id=task_id, workspace=workspace
|
||||
)
|
||||
return await self.invoke(skill_name, ctx, **kwargs)
|
||||
|
||||
def card(self) -> AgentCard:
|
||||
return AgentCard.from_agent(self)
|
||||
|
||||
@staticmethod
|
||||
def _validate_inputs(
|
||||
spec: SkillSpec, kwargs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
known = {p.name for p in spec.params}
|
||||
unknown = set(kwargs) - known
|
||||
if unknown:
|
||||
raise ValueError(f"unknown parameters: {sorted(unknown)}")
|
||||
|
||||
validated: dict[str, Any] = {}
|
||||
for p in spec.params:
|
||||
if p.name in kwargs:
|
||||
validated[p.name] = p.adapter.validate_python(kwargs[p.name])
|
||||
elif not p.has_default:
|
||||
raise ValueError(f"missing required parameter: {p.name!r}")
|
||||
# else: omit so the handler's own default applies
|
||||
return validated
|
||||
Reference in New Issue
Block a user