# 469 lines · 13 KiB · Python
from __future__ import annotations
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from a2a_pack import (
|
|
A2AAgent,
|
|
AgentCard,
|
|
APIKeyAuth,
|
|
EgressPolicy,
|
|
JWTAuth,
|
|
Lifecycle,
|
|
LocalRunContext,
|
|
MissingScopes,
|
|
NoAuth,
|
|
Resources,
|
|
RunContext,
|
|
Sandbox,
|
|
SkillInputError,
|
|
SkillInvocationError,
|
|
SkillNotFound,
|
|
State,
|
|
skill,
|
|
)
|
|
|
|
|
|
class _GreeterConfig(BaseModel):
    # Config model for the _Greeter fixture agent.
    # ``suffix`` is appended to each greeting and defaults to "!".
    suffix: str = "!"
|
|
|
|
|
|
class _Greeter(A2AAgent[_GreeterConfig, NoAuth]):
    # Shared fixture agent: exposes a working skill ("greet") and a skill
    # that always raises ("boom").
    name = "greeter"
    description = "Says hi"
    config_model = _GreeterConfig
    auth_model = NoAuth

    @skill(description="Greet someone")
    async def greet(self, ctx: RunContext[NoAuth], who: str, loud: bool = False) -> str:
        # A progress event is emitted before the greeting is built.
        await ctx.emit_progress(f"greeting {who}")
        greeting = f"hello {who}{self.config.suffix}"
        if loud:
            return greeting.upper()
        return greeting

    @skill(name="boom", description="Fails on purpose")
    async def _boom(self, ctx: RunContext[NoAuth]) -> str:
        # Registered under the explicit skill name "boom", not "_boom".
        raise ValueError("nope")
|
|
|
|
|
|
def _ctx() -> LocalRunContext[NoAuth]:
    """Return a fresh ``LocalRunContext`` built with ``NoAuth`` credentials."""
    no_auth = NoAuth()
    return LocalRunContext(auth=no_auth)
|
|
|
|
|
|
# --- subclass / decorator validation ---
|
|
|
|
def test_subclass_without_name_rejected():
    """Declaring an agent subclass without ``name`` raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="name must be set")
    with expect_error:

        class _Bad(A2AAgent):
            description = "missing name"
|
|
|
|
|
|
def test_skill_requires_async():
    """Decorating a synchronous method with ``@skill`` raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="async")
    with expect_error:

        class _Sync(A2AAgent):
            name = "sync"

            # Deliberately a plain ``def`` rather than ``async def``.
            @skill(description="sync handler")
            def hi(self, ctx: RunContext[NoAuth]) -> str:  # type: ignore[misc]
                return "hi"
|
|
|
|
|
|
def test_skill_requires_run_context_param():
    """A skill handler lacking a ``RunContext`` parameter raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="RunContext")
    with expect_error:

        class _NoCtx(A2AAgent):
            name = "noctx"

            # No ``ctx`` parameter follows ``self`` here, on purpose.
            @skill(description="missing ctx")
            async def hi(self, who: str) -> str:
                return who
|
|
|
|
|
|
def test_skill_rejects_var_args():
    """A skill handler declaring ``*args`` raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match=r"\*args")
    with expect_error:

        class _Va(A2AAgent):
            name = "va"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], *args: str) -> str:
                return ",".join(args)
|
|
|
|
|
|
def test_skill_rejects_var_kwargs():
    """A skill handler declaring ``**kwargs`` raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match=r"\*\*kwargs")
    with expect_error:

        class _Vk(A2AAgent):
            name = "vk"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], **kwargs: str) -> str:
                return str(kwargs)
|
|
|
|
|
|
def test_skill_rejects_untyped_param():
    """A skill parameter without a type annotation raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="missing a type annotation")
    with expect_error:

        class _U(A2AAgent):
            name = "u"

            # ``who`` is intentionally left unannotated.
            @skill()
            async def hi(self, ctx: RunContext[NoAuth], who) -> str:  # type: ignore[no-untyped-def]
                return who
|
|
|
|
|
|
def test_skill_rejects_reserved_name():
    """A skill parameter named ``context`` is rejected as reserved."""
    expect_error = pytest.raises(TypeError, match="reserved")
    with expect_error:

        class _R(A2AAgent):
            name = "r"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], context: str) -> str:
                return context
|
|
|
|
|
|
def test_duplicate_skill_name_rejected():
    """Two handlers registered under one skill name raise ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="duplicate skill name")
    with expect_error:

        class _Dup(A2AAgent):
            name = "dup"

            # Both handlers below claim the skill name "x".
            @skill(name="x")
            async def a(self, ctx: RunContext[NoAuth]) -> str:
                return "a"

            @skill(name="x")
            async def b(self, ctx: RunContext[NoAuth]) -> str:
                return "b"
|
|
|
|
|
|
# --- card / metadata ---
|
|
|
|
def test_skills_collected_with_metadata():
    """The greeter exposes both skills, and "greet" keeps its description."""
    agent = _Greeter()
    skill_names = set(agent.skills)
    assert skill_names == {"greet", "boom"}
    greet_skill = agent.skills["greet"]
    assert greet_skill.description == "Greet someone"
|
|
|
|
|
|
def test_card_omits_ctx_param():
    """The "greet" input schema on the card excludes ``ctx`` and is closed."""
    card = _Greeter().card()
    assert isinstance(card, AgentCard)
    greet_entry = next(entry for entry in card.skills if entry.name == "greet")
    schema = greet_entry.input_schema
    assert "ctx" not in schema["properties"]
    assert schema["required"] == ["who"]
    assert schema["additionalProperties"] is False
|
|
|
|
|
|
def test_card_includes_deploy_metadata():
    """Deploy-related class attributes are surfaced on the agent card.

    Checks ``required_secrets``, ``required_env``, ``capabilities``,
    ``input_modes`` and ``output_modes`` declared on the agent class.
    """

    class _Cfg(BaseModel):
        pass

    class _Meta(A2AAgent[_Cfg, NoAuth]):
        name = "meta"
        description = "metadata showcase"
        required_secrets = ("OPENAI_KEY",)
        required_env = ("REGION",)
        capabilities = {"streaming": True}
        input_modes = ("application/json", "text/plain")
        output_modes = ("application/json",)

        @skill()
        async def noop(self, ctx: RunContext[NoAuth]) -> str:
            return "ok"

    card = _Meta().card()
    # Tuple-valued class attributes are compared against lists on the card.
    assert card.required_secrets == ["OPENAI_KEY"]
    assert card.required_env == ["REGION"]
    assert card.capabilities == {"streaming": True}
    assert card.input_modes == ["application/json", "text/plain"]
    # ``output_modes`` was declared on the fixture but never verified before;
    # the expected list form mirrors the ``input_modes`` assertion above.
    assert card.output_modes == ["application/json"]
|
|
|
|
|
|
# --- config hydration ---
|
|
|
|
def test_default_config_constructed_when_none():
    """With no config argument, the agent carries the default suffix."""
    agent = _Greeter()
    assert agent.config.suffix == "!"
|
|
|
|
|
|
def test_explicit_config_used():
    """A config instance passed to the constructor is the one exposed."""
    custom_config = _GreeterConfig(suffix="?!")
    agent = _Greeter(custom_config)
    assert agent.config.suffix == "?!"
|
|
|
|
|
|
def test_config_accepts_dict():
    """A plain dict passed as config is exposed via ``agent.config``."""
    raw_config = {"suffix": "?"}
    agent = _Greeter(raw_config)
    assert agent.config.suffix == "?"
|
|
|
|
|
|
def test_config_dict_validation_errors_propagate():
    """Constructing the agent from an invalid config dict raises."""
    invalid_config = {"suffix": 123, "extra": True}
    # NOTE(review): ``Exception`` is very broad; presumably a pydantic
    # validation error is expected here. Confirm, then narrow the type.
    with pytest.raises(Exception):
        _Greeter(invalid_config)  # type: ignore[arg-type]
|
|
|
|
|
|
# --- invocation ---
|
|
|
|
async def test_invoke_passes_ctx_and_returns_value():
    """``invoke`` returns the greeting and the context records progress."""
    agent = _Greeter()
    run_ctx = _ctx()
    greeting = await agent.invoke("greet", run_ctx, who="bob")
    assert greeting == "hello bob!"
    # The handler's ``emit_progress`` call must show up on the same context.
    assert any(event.kind == "progress" for event in run_ctx.events)
|
|
|
|
|
|
async def test_invoke_validates_input_types():
    """Passing an int for the ``str`` parameter raises ``SkillInputError``."""
    agent = _Greeter()
    expect_error = pytest.raises(SkillInputError)
    with expect_error:
        await agent.invoke("greet", _ctx(), who=123)  # type: ignore[arg-type]
|
|
|
|
|
|
async def test_invoke_rejects_unknown_param():
    """An argument the skill does not declare raises ``SkillInputError``."""
    agent = _Greeter()
    expect_error = pytest.raises(SkillInputError, match="unknown parameter")
    with expect_error:
        await agent.invoke("greet", _ctx(), who="bob", extra="nope")
|
|
|
|
|
|
async def test_invoke_missing_required_param():
    """Omitting the required ``who`` argument raises ``SkillInputError``."""
    agent = _Greeter()
    expect_error = pytest.raises(SkillInputError, match="missing required")
    with expect_error:
        await agent.invoke("greet", _ctx())
|
|
|
|
|
|
async def test_invoke_unknown_skill_raises():
    """Invoking a skill name the agent lacks raises ``SkillNotFound``."""
    agent = _Greeter()
    expect_error = pytest.raises(SkillNotFound)
    with expect_error:
        await agent.invoke("nope", _ctx())
|
|
|
|
|
|
async def test_invoke_handler_error_wrapped():
    """A handler's ``ValueError`` surfaces as ``SkillInvocationError``."""
    agent = _Greeter()
    # The match pattern requires the original exception type in the message.
    expect_error = pytest.raises(SkillInvocationError, match="ValueError")
    with expect_error:
        await agent.invoke("boom", _ctx())
|
|
|
|
|
|
class _Pair(BaseModel):
    # Structured return type used by _PairAgent.make: two integer fields.
    a: int
    b: int
|
|
|
|
|
|
class _PairCfg(BaseModel):
    # Empty config model for _PairAgent; it declares no fields.
    pass
|
|
|
|
|
|
class _PairAgent(A2AAgent[_PairCfg, NoAuth]):
    # Fixture agent whose single skill returns a pydantic model instance.
    name = "json-out"
    description = ""

    @skill()
    async def make(self, ctx: RunContext[NoAuth]) -> _Pair:
        pair = _Pair(a=1, b=2)
        return pair
|
|
|
|
|
|
async def test_invoke_json_returns_serializable():
    """``invoke_json`` returns the model result as a plain dict."""
    agent = _PairAgent()
    run_ctx = LocalRunContext(auth=NoAuth())
    payload = await agent.invoke_json("make", run_ctx, {})
    assert payload == {"a": 1, "b": 2}
|
|
|
|
|
|
# --- scopes ---
|
|
|
|
async def test_scope_enforcement_blocks_caller():
    """A caller without the "admin" scope is blocked; one with it succeeds."""

    class _Cfg(BaseModel):
        pass

    class _S(A2AAgent[_Cfg, JWTAuth]):
        name = "scoped-agent"
        description = ""
        auth_model = JWTAuth

        @skill(scopes=["admin"])
        async def secret_op(self, ctx: RunContext[JWTAuth]) -> str:
            return "ok"

    # Caller holding only "read" must be refused.
    reader_ctx = LocalRunContext(auth=JWTAuth(sub="alice", scopes=["read"]))
    expect_error = pytest.raises(MissingScopes)
    with expect_error:
        await _S().invoke("secret_op", reader_ctx)

    # Same subject with the "admin" scope gets the handler's result.
    admin_ctx = LocalRunContext(auth=JWTAuth(sub="alice", scopes=["admin"]))
    outcome = await _S().invoke("secret_op", admin_ctx)
    assert outcome == "ok"
|
|
|
|
|
|
async def test_scope_enforcement_allows_no_scope_skill():
    """A skill declared without scopes can be invoked with ``NoAuth``."""
    agent = _Greeter()
    greeting = await agent.invoke("greet", _ctx(), who="bob")
    assert greeting == "hello bob!"
|
|
|
|
|
|
# --- streaming helpers ---
|
|
|
|
async def test_stream_helpers_emit_typed_events():
    """The emit helpers append events, in call order, with the right kinds."""
    run_ctx = _ctx()
    await run_ctx.emit_text_delta("hello ")
    await run_ctx.emit_text_delta("world")
    await run_ctx.emit_error("uh oh", code="E001")

    recorded_kinds = [event.kind for event in run_ctx.events]
    assert recorded_kinds == ["text_delta", "text_delta", "error"]

    # The error event is last and carries both message and code.
    last_event = run_ctx.events[-1]
    assert last_event.payload == {"message": "uh oh", "code": "E001"}
|
|
|
|
|
|
# --- health ---
|
|
|
|
async def test_default_health_is_true():
    """An agent that does not override ``health`` reports ``True``."""
    healthy = await _Greeter().health()
    assert healthy is True
|
|
|
|
|
|
# --- local_invoke ---
|
|
|
|
async def test_local_invoke_default_no_auth():
    """``local_invoke`` works without the caller supplying a context or auth."""
    agent = _Greeter()
    greeting = await agent.local_invoke("greet", who="bob")
    assert greeting == "hello bob!"
|
|
|
|
|
|
async def test_local_invoke_with_explicit_auth():
    """Auth passed to ``local_invoke`` is visible to the handler as ``ctx.auth``."""

    class _Cfg(BaseModel):
        pass

    class _A(A2AAgent[_Cfg, JWTAuth]):
        name = "auth-test"
        description = ""
        auth_model = JWTAuth

        @skill()
        async def whoami(self, ctx: RunContext[JWTAuth]) -> str:
            identity = ctx.auth
            return f"{identity.sub}@{identity.org_id}"

    agent = _A()
    credentials = JWTAuth(sub="alice", org_id="acme")
    result = await agent.local_invoke("whoami", auth=credentials)
    assert result == "alice@acme"
|
|
|
|
|
|
# --- artifacts / inheritance ---
|
|
|
|
async def test_artifact_round_trip():
    """An artifact written by a skill is stored on the context and announced."""

    class _Cfg(BaseModel):
        pass

    class _A(A2AAgent[_Cfg, NoAuth]):
        name = "art"
        description = ""

        @skill()
        async def write(self, ctx: RunContext[NoAuth], body: str) -> str:
            # Store the encoded body, announce it, then return its URI.
            artifact = await ctx.write_artifact("note.txt", body.encode(), "text/plain")
            await ctx.emit_artifact(artifact)
            return artifact.uri

    run_ctx = _ctx()
    returned_uri = await _A().invoke("write", run_ctx, body="hello")
    assert returned_uri.startswith("memory://")
    assert run_ctx.artifacts["note.txt"] == b"hello"
    assert any(event.kind == "artifact" for event in run_ctx.events)
|
|
|
|
|
|
def test_skill_inheritance_preserves_parent_skills():
    """A subclass that adds a skill still exposes the parent's skills."""

    class _Loud(_Greeter):
        name = "loud-greeter"

        @skill(description="Shout")
        async def shout(self, ctx: RunContext[NoAuth], what: str) -> str:
            return what.upper()

    inherited_and_own = set(_Loud().skills)
    assert inherited_and_own == {"greet", "boom", "shout"}
|
|
|
|
|
|
# --- runtime metadata ---
|
|
|
|
|
|
class _SessionState(BaseModel):
    # State model used by the runtime-metadata tests below.
    # ``history`` is a list of strings that defaults to empty.
    history: list[str] = []
|
|
|
|
|
|
def test_default_runtime_is_ephemeral_no_state_microsandbox():
    """An agent declaring no runtime attributes gets the default runtime."""
    runtime = _Greeter.runtime()
    assert runtime.lifecycle is Lifecycle.EPHEMERAL
    assert runtime.state is State.NONE
    assert runtime.sandbox is Sandbox.MICROSANDBOX  # safe-by-default
    assert runtime.concurrency == 1
|
|
|
|
|
|
def test_state_requires_state_model():
    """Declaring session state without a ``state_model`` raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="state_model")
    with expect_error:

        class _Bad(A2AAgent):
            name = "bad"
            description = ""
            # No ``state_model`` accompanies this state declaration.
            state = State.SESSION
|
|
|
|
|
|
def test_ephemeral_lifecycle_incompatible_with_session_state():
    """Ephemeral lifecycle combined with session state raises ``TypeError``."""
    expect_error = pytest.raises(TypeError, match="ephemeral.*session")
    with expect_error:

        class _Bad(A2AAgent):
            name = "bad"
            description = ""
            # The conflicting pair under test:
            lifecycle = Lifecycle.EPHEMERAL
            state = State.SESSION
            state_model = _SessionState
|
|
|
|
|
|
def test_runtime_metadata_propagates_to_card():
    """Runtime attributes and per-skill policy declared on the agent reach the card."""

    class _ChatCfg(BaseModel):
        pass

    class _Chat(A2AAgent[_ChatCfg, JWTAuth]):
        name = "chat"
        description = "stateful chat agent"
        auth_model = JWTAuth
        lifecycle = Lifecycle.SESSION
        state = State.SESSION
        state_model = _SessionState
        resources = Resources(cpu="2", memory="4Gi", gpu=0, max_runtime_seconds=1800)
        concurrency = 4
        egress = EgressPolicy(
            allow_hosts=("api.openai.com",),
            allow_internal_services=("litellm.llm.svc.cluster.local",),
        )
        tools_used = ("litellm", "minio")

        @skill(timeout_seconds=30, idempotent=True, max_retries=2, cost_class="cheap")
        async def reply(self, ctx: RunContext[JWTAuth], message: str) -> str:
            return f"echo {message}"

    card = _Chat().card()

    # Agent-level runtime settings.
    assert card.runtime.lifecycle is Lifecycle.SESSION
    assert card.runtime.state is State.SESSION
    assert card.runtime.sandbox is Sandbox.MICROSANDBOX
    assert card.runtime.resources.cpu == "2"
    assert card.runtime.concurrency == 4
    assert card.runtime.egress.allow_hosts == ("api.openai.com",)
    assert card.runtime.tools_used == ("litellm", "minio")

    # The state schema on the card exposes the model's field.
    assert card.state_schema is not None
    assert "history" in card.state_schema["properties"]

    # Policy declared through the ``@skill`` arguments of the only skill.
    reply_entry = card.skills[0]
    assert reply_entry.policy.timeout_seconds == 30
    assert reply_entry.policy.idempotent is True
    assert reply_entry.policy.max_retries == 2
    assert reply_entry.policy.cost_class == "cheap"
|
|
|
|
|
|
def test_skill_metadata_propagates_to_card():
    """``scopes``, ``stream`` and ``tags`` given to ``@skill`` reach the card."""

    class _ScopedConfig(BaseModel):
        pass

    class _Scoped(A2AAgent[_ScopedConfig, APIKeyAuth]):
        name = "scoped"
        description = "scope test"
        config_model = _ScopedConfig
        auth_model = APIKeyAuth

        @skill(scopes=["a:read", "a:write"], stream=True, tags=["x"])
        async def do(self, ctx: RunContext[APIKeyAuth]) -> str:
            return "ok"

    card = _Scoped().card()
    # The agent defines a single skill, so it is the first card entry.
    skill_entry = card.skills[0]
    assert skill_entry.scopes == ["a:read", "a:write"]
    assert skill_entry.stream is True
    assert skill_entry.tags == ["x"]
|