# Files: a2a/tests/test_agent.py
# 2026-05-08 21:59:51 -03:00
# 469 lines, 13 KiB, Python

from __future__ import annotations
import pytest
from pydantic import BaseModel
from a2a_pack import (
A2AAgent,
AgentCard,
APIKeyAuth,
EgressPolicy,
JWTAuth,
Lifecycle,
LocalRunContext,
MissingScopes,
NoAuth,
Resources,
RunContext,
Sandbox,
SkillInputError,
SkillInvocationError,
SkillNotFound,
State,
skill,
)
class _GreeterConfig(BaseModel):
    # Config for the greeter fixture agent; ``suffix`` is appended to each
    # greeting and defaults to an exclamation mark.
    suffix: str = "!"
class _Greeter(A2AAgent[_GreeterConfig, NoAuth]):
    # Shared fixture agent: one well-behaved skill ("greet") and one skill
    # ("boom") that always raises.
    name = "greeter"
    description = "Says hi"
    config_model = _GreeterConfig
    auth_model = NoAuth

    @skill(description="Greet someone")
    async def greet(self, ctx: RunContext[NoAuth], who: str, loud: bool = False) -> str:
        # A progress event is emitted before the greeting is built.
        await ctx.emit_progress(f"greeting {who}")
        greeting = f"hello {who}{self.config.suffix}"
        if loud:
            return greeting.upper()
        return greeting

    @skill(name="boom", description="Fails on purpose")
    async def _boom(self, ctx: RunContext[NoAuth]) -> str:
        # Registered under the explicit skill name "boom"; never returns.
        raise ValueError("nope")
def _ctx() -> LocalRunContext[NoAuth]:
    """Return a new ``LocalRunContext`` built with a ``NoAuth`` instance."""
    no_auth = NoAuth()
    return LocalRunContext(auth=no_auth)
# --- subclass / decorator validation ---
def test_subclass_without_name_rejected():
    """An agent subclass that omits ``name`` is rejected with ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="name must be set")
    with expect_type_error:

        class _Bad(A2AAgent):
            description = "missing name"
def test_skill_requires_async():
    """Decorating a synchronous method with ``@skill`` raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="async")
    with expect_type_error:

        class _Sync(A2AAgent):
            name = "sync"

            # Deliberately a plain ``def`` rather than ``async def``.
            @skill(description="sync handler")
            def hi(self, ctx: RunContext[NoAuth]) -> str:  # type: ignore[misc]
                return "hi"
def test_skill_requires_run_context_param():
    """A skill handler without a ``RunContext`` parameter raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="RunContext")
    with expect_type_error:

        class _NoCtx(A2AAgent):
            name = "noctx"

            # Deliberately has no ``ctx`` parameter after ``self``.
            @skill(description="missing ctx")
            async def hi(self, who: str) -> str:
                return who
def test_skill_rejects_var_args():
    """A skill handler that declares ``*args`` raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match=r"\*args")
    with expect_type_error:

        class _Va(A2AAgent):
            name = "va"

            # Variadic positional parameter is the thing under test.
            @skill()
            async def hi(self, ctx: RunContext[NoAuth], *args: str) -> str:
                return ",".join(args)
def test_skill_rejects_var_kwargs():
    """A skill handler that declares ``**kwargs`` raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match=r"\*\*kwargs")
    with expect_type_error:

        class _Vk(A2AAgent):
            name = "vk"

            # Variadic keyword parameter is the thing under test.
            @skill()
            async def hi(self, ctx: RunContext[NoAuth], **kwargs: str) -> str:
                return str(kwargs)
def test_skill_rejects_untyped_param():
    """A skill parameter with no type annotation raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="missing a type annotation")
    with expect_type_error:

        class _U(A2AAgent):
            name = "u"

            # ``who`` is intentionally left unannotated.
            @skill()
            async def hi(self, ctx: RunContext[NoAuth], who) -> str:  # type: ignore[no-untyped-def]
                return who
def test_skill_rejects_reserved_name():
    """A skill parameter named ``context`` is rejected as reserved."""
    expect_type_error = pytest.raises(TypeError, match="reserved")
    with expect_type_error:

        class _R(A2AAgent):
            name = "r"

            # The parameter name ``context`` is what triggers the rejection.
            @skill()
            async def hi(self, ctx: RunContext[NoAuth], context: str) -> str:
                return context
def test_duplicate_skill_name_rejected():
    """Two handlers registered under the same skill name raise ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="duplicate skill name")
    with expect_type_error:

        class _Dup(A2AAgent):
            name = "dup"

            # Both handlers claim the explicit skill name "x".
            @skill(name="x")
            async def a(self, ctx: RunContext[NoAuth]) -> str:
                return "a"

            @skill(name="x")
            async def b(self, ctx: RunContext[NoAuth]) -> str:
                return "b"
# --- card / metadata ---
def test_skills_collected_with_metadata():
    """The greeter exposes both skills, keyed by name, with their descriptions."""
    agent = _Greeter()
    assert set(agent.skills) == {"greet", "boom"}
    greet_skill = agent.skills["greet"]
    assert greet_skill.description == "Greet someone"
def test_card_omits_ctx_param():
    """The card's input schema for ``greet`` excludes ``ctx`` and is closed."""
    agent_card = _Greeter().card()
    assert isinstance(agent_card, AgentCard)

    greet_entry = next(entry for entry in agent_card.skills if entry.name == "greet")
    # ``loud`` has a default, so only ``who`` is required.
    assert "ctx" not in greet_entry.input_schema["properties"]
    assert greet_entry.input_schema["required"] == ["who"]
    assert greet_entry.input_schema["additionalProperties"] is False
def test_card_includes_deploy_metadata():
    """Deploy-related class attributes are surfaced on the agent card.

    The agent declares its secrets, environment variables, capabilities and
    I/O modes as class attributes (tuples / dict); the card is expected to
    expose the sequence-valued ones as lists.
    """

    class _Cfg(BaseModel):
        pass

    class _Meta(A2AAgent[_Cfg, NoAuth]):
        name = "meta"
        description = "metadata showcase"
        required_secrets = ("OPENAI_KEY",)
        required_env = ("REGION",)
        capabilities = {"streaming": True}
        input_modes = ("application/json", "text/plain")
        output_modes = ("application/json",)

        @skill()
        async def noop(self, ctx: RunContext[NoAuth]) -> str:
            return "ok"

    card = _Meta().card()
    assert card.required_secrets == ["OPENAI_KEY"]
    assert card.required_env == ["REGION"]
    assert card.capabilities == {"streaming": True}
    assert card.input_modes == ["application/json", "text/plain"]
    # ``output_modes`` is declared above but was previously never checked;
    # assert it the same way as ``input_modes``.
    assert card.output_modes == ["application/json"]
# --- config hydration ---
def test_default_config_constructed_when_none():
    """Constructing the agent with no config yields the model's defaults."""
    agent = _Greeter()
    suffix = agent.config.suffix
    assert suffix == "!"
def test_explicit_config_used():
    """A config model instance passed to the constructor is used as-is."""
    custom = _GreeterConfig(suffix="?!")
    agent = _Greeter(custom)
    assert agent.config.suffix == "?!"
def test_config_accepts_dict():
    """A plain dict passed as config is hydrated into the config model."""
    raw_config = {"suffix": "?"}
    agent = _Greeter(raw_config)
    assert agent.config.suffix == "?"
def test_config_dict_validation_errors_propagate():
    """An invalid config dict surfaces pydantic's ``ValidationError``.

    ``suffix`` is declared as ``str`` but is given an ``int`` here.  The
    expected exception is pinned to ``ValidationError`` rather than the
    catch-all ``Exception`` so that an unrelated failure inside the
    constructor cannot make this test pass by accident.
    """
    # Imported locally so this test does not depend on the module import block.
    from pydantic import ValidationError

    with pytest.raises(ValidationError):
        _Greeter({"suffix": 123, "extra": True})  # type: ignore[arg-type]
# --- invocation ---
async def test_invoke_passes_ctx_and_returns_value():
    """``invoke`` returns the handler's value and the handler can emit events."""
    agent = _Greeter()
    run_ctx = _ctx()

    reply = await agent.invoke("greet", run_ctx, who="bob")
    assert reply == "hello bob!"

    # ``greet`` emits a progress event through the context it was handed.
    assert any(event.kind == "progress" for event in run_ctx.events)
async def test_invoke_validates_input_types():
    """Passing an ``int`` for the ``str`` parameter raises ``SkillInputError``."""
    agent = _Greeter()
    expect_input_error = pytest.raises(SkillInputError)
    with expect_input_error:
        await agent.invoke("greet", _ctx(), who=123)  # type: ignore[arg-type]
async def test_invoke_rejects_unknown_param():
    """An argument the skill does not declare raises ``SkillInputError``."""
    agent = _Greeter()
    expect_input_error = pytest.raises(SkillInputError, match="unknown parameter")
    with expect_input_error:
        await agent.invoke("greet", _ctx(), who="bob", extra="nope")
async def test_invoke_missing_required_param():
    """Omitting the required ``who`` argument raises ``SkillInputError``."""
    agent = _Greeter()
    expect_input_error = pytest.raises(SkillInputError, match="missing required")
    with expect_input_error:
        await agent.invoke("greet", _ctx())
async def test_invoke_unknown_skill_raises():
    """Invoking a skill name that is not registered raises ``SkillNotFound``."""
    agent = _Greeter()
    expect_not_found = pytest.raises(SkillNotFound)
    with expect_not_found:
        await agent.invoke("nope", _ctx())
async def test_invoke_handler_error_wrapped():
    """A handler exception surfaces as ``SkillInvocationError`` naming its type."""
    agent = _Greeter()
    # "boom" raises ValueError; the wrapper's message must mention it.
    expect_wrapped = pytest.raises(SkillInvocationError, match="ValueError")
    with expect_wrapped:
        await agent.invoke("boom", _ctx())
class _Pair(BaseModel):
    # Two-integer model used as a skill return value in the JSON test.
    a: int
    b: int
class _PairCfg(BaseModel):
    # Empty config model for ``_PairAgent``.
    pass
class _PairAgent(A2AAgent[_PairCfg, NoAuth]):
    # Fixture agent whose single skill returns a pydantic model.
    name = "json-out"
    description = ""

    @skill()
    async def make(self, ctx: RunContext[NoAuth]) -> _Pair:
        pair = _Pair(a=1, b=2)
        return pair
async def test_invoke_json_returns_serializable():
    """``invoke_json`` turns the model returned by the skill into a plain dict."""
    agent = _PairAgent()
    run_ctx = LocalRunContext(auth=NoAuth())
    payload = await agent.invoke_json("make", run_ctx, {})
    assert payload == {"a": 1, "b": 2}
# --- scopes ---
async def test_scope_enforcement_blocks_caller():
    """A skill requiring ``admin`` rejects a caller lacking it and admits one with it."""

    class _Cfg(BaseModel):
        pass

    class _S(A2AAgent[_Cfg, JWTAuth]):
        name = "scoped-agent"
        description = ""
        auth_model = JWTAuth

        @skill(scopes=["admin"])
        async def secret_op(self, ctx: RunContext[JWTAuth]) -> str:
            return "ok"

    # Caller holding only "read" must be refused.
    read_only_ctx = LocalRunContext(auth=JWTAuth(sub="alice", scopes=["read"]))
    with pytest.raises(MissingScopes):
        await _S().invoke("secret_op", read_only_ctx)

    # Same caller with "admin" gets through.
    admin_ctx = LocalRunContext(auth=JWTAuth(sub="alice", scopes=["admin"]))
    outcome = await _S().invoke("secret_op", admin_ctx)
    assert outcome == "ok"
async def test_scope_enforcement_allows_no_scope_skill():
    """A skill that declares no scopes can be invoked with a ``NoAuth`` context."""
    agent = _Greeter()
    reply = await agent.invoke("greet", _ctx(), who="bob")
    assert reply == "hello bob!"
# --- streaming helpers ---
async def test_stream_helpers_emit_typed_events():
    """The emit helpers record events in call order with the expected kinds."""
    run_ctx = _ctx()
    await run_ctx.emit_text_delta("hello ")
    await run_ctx.emit_text_delta("world")
    await run_ctx.emit_error("uh oh", code="E001")

    observed_kinds = [event.kind for event in run_ctx.events]
    assert observed_kinds == ["text_delta", "text_delta", "error"]

    # The error event carries both the message and the code.
    last_event = run_ctx.events[-1]
    assert last_event.payload == {"message": "uh oh", "code": "E001"}
# --- health ---
async def test_default_health_is_true():
    """An agent that does not override ``health`` reports ``True``."""
    agent = _Greeter()
    healthy = await agent.health()
    assert healthy is True
# --- local_invoke ---
async def test_local_invoke_default_no_auth():
    """``local_invoke`` works without the caller supplying a context or auth."""
    agent = _Greeter()
    reply = await agent.local_invoke("greet", who="bob")
    assert reply == "hello bob!"
async def test_local_invoke_with_explicit_auth():
    """An ``auth`` object given to ``local_invoke`` is visible as ``ctx.auth``."""

    class _Cfg(BaseModel):
        pass

    class _A(A2AAgent[_Cfg, JWTAuth]):
        name = "auth-test"
        description = ""
        auth_model = JWTAuth

        @skill()
        async def whoami(self, ctx: RunContext[JWTAuth]) -> str:
            return f"{ctx.auth.sub}@{ctx.auth.org_id}"

    principal = JWTAuth(sub="alice", org_id="acme")
    identity = await _A().local_invoke("whoami", auth=principal)
    assert identity == "alice@acme"
# --- artifacts / inheritance ---
async def test_artifact_round_trip():
    """An artifact written by a skill is stored on the context and announced."""

    class _Cfg(BaseModel):
        pass

    class _A(A2AAgent[_Cfg, NoAuth]):
        name = "art"
        description = ""

        @skill()
        async def write(self, ctx: RunContext[NoAuth], body: str) -> str:
            ref = await ctx.write_artifact("note.txt", body.encode(), "text/plain")
            await ctx.emit_artifact(ref)
            return ref.uri

    run_ctx = _ctx()
    artifact_uri = await _A().invoke("write", run_ctx, body="hello")

    assert artifact_uri.startswith("memory://")
    # The stored bytes are the encoded body passed to the skill.
    assert run_ctx.artifacts["note.txt"] == b"hello"
    assert any(event.kind == "artifact" for event in run_ctx.events)
def test_skill_inheritance_preserves_parent_skills():
    """A subclass exposes its parent's skills alongside its own."""

    class _Loud(_Greeter):
        name = "loud-greeter"

        @skill(description="Shout")
        async def shout(self, ctx: RunContext[NoAuth], what: str) -> str:
            return what.upper()

    # "greet" and "boom" come from ``_Greeter``; "shout" is added here.
    assert set(_Loud().skills) == {"greet", "boom", "shout"}
# --- runtime metadata ---
class _SessionState(BaseModel):
    # State model used by the runtime-metadata tests; a single list-of-strings
    # field that defaults to empty.
    history: list[str] = []
def test_default_runtime_is_ephemeral_no_state_microsandbox():
    """An agent with no runtime overrides gets the default runtime settings."""
    defaults = _Greeter.runtime()
    assert defaults.lifecycle is Lifecycle.EPHEMERAL
    assert defaults.state is State.NONE
    assert defaults.sandbox is Sandbox.MICROSANDBOX  # safe-by-default
    assert defaults.concurrency == 1
def test_state_requires_state_model():
    """Declaring session state without a ``state_model`` raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="state_model")
    with expect_type_error:

        class _Bad(A2AAgent):
            name = "bad"
            description = ""
            # No ``state_model`` accompanies this declaration.
            state = State.SESSION
def test_ephemeral_lifecycle_incompatible_with_session_state():
    """Combining an ephemeral lifecycle with session state raises ``TypeError``."""
    expect_type_error = pytest.raises(TypeError, match="ephemeral.*session")
    with expect_type_error:

        class _Bad(A2AAgent):
            name = "bad"
            description = ""
            # The conflicting pair under test; a state model is supplied so
            # that only the lifecycle/state combination is at fault.
            lifecycle = Lifecycle.EPHEMERAL
            state = State.SESSION
            state_model = _SessionState
def test_runtime_metadata_propagates_to_card():
    """Runtime attributes and per-skill policy declared on the agent reach the card."""

    class _ChatCfg(BaseModel):
        pass

    class _Chat(A2AAgent[_ChatCfg, JWTAuth]):
        name = "chat"
        description = "stateful chat agent"
        auth_model = JWTAuth
        lifecycle = Lifecycle.SESSION
        state = State.SESSION
        state_model = _SessionState
        resources = Resources(cpu="2", memory="4Gi", gpu=0, max_runtime_seconds=1800)
        concurrency = 4
        egress = EgressPolicy(
            allow_hosts=("api.openai.com",),
            allow_internal_services=("litellm.llm.svc.cluster.local",),
        )
        tools_used = ("litellm", "minio")

        @skill(timeout_seconds=30, idempotent=True, max_retries=2, cost_class="cheap")
        async def reply(self, ctx: RunContext[JWTAuth], message: str) -> str:
            return f"echo {message}"

    chat_card = _Chat().card()

    # Agent-level runtime settings.
    assert chat_card.runtime.lifecycle is Lifecycle.SESSION
    assert chat_card.runtime.state is State.SESSION
    assert chat_card.runtime.sandbox is Sandbox.MICROSANDBOX
    assert chat_card.runtime.resources.cpu == "2"
    assert chat_card.runtime.concurrency == 4
    assert chat_card.runtime.egress.allow_hosts == ("api.openai.com",)
    assert chat_card.runtime.tools_used == ("litellm", "minio")

    # The state model's schema is published on the card.
    assert chat_card.state_schema is not None
    assert "history" in chat_card.state_schema["properties"]

    # Per-skill policy taken from the ``@skill`` keyword arguments.
    reply_policy = chat_card.skills[0].policy
    assert reply_policy.timeout_seconds == 30
    assert reply_policy.idempotent is True
    assert reply_policy.max_retries == 2
    assert reply_policy.cost_class == "cheap"
def test_skill_metadata_propagates_to_card():
    """Scopes, stream flag and tags given to ``@skill`` appear on the card entry."""

    class _ScopedConfig(BaseModel):
        pass

    class _Scoped(A2AAgent[_ScopedConfig, APIKeyAuth]):
        name = "scoped"
        description = "scope test"
        config_model = _ScopedConfig
        auth_model = APIKeyAuth

        @skill(scopes=["a:read", "a:write"], stream=True, tags=["x"])
        async def do(self, ctx: RunContext[APIKeyAuth]) -> str:
            return "ok"

    # The agent has exactly one skill, so it is the first card entry.
    only_skill = _Scoped().card().skills[0]
    assert only_skill.scopes == ["a:read", "a:write"]
    assert only_skill.stream is True
    assert only_skill.tags == ["x"]