"""Tests for agent definition, skill invocation, and card metadata in ``a2a_pack``."""

from __future__ import annotations

import pytest
from pydantic import BaseModel

from a2a_pack import (
    A2AAgent,
    AgentCard,
    APIKeyAuth,
    EgressPolicy,
    JWTAuth,
    Lifecycle,
    LocalRunContext,
    MissingScopes,
    NoAuth,
    Resources,
    RunContext,
    Sandbox,
    SkillInputError,
    SkillInvocationError,
    SkillNotFound,
    State,
    skill,
)


class _GreeterConfig(BaseModel):
    # Config for the sample greeter: the text appended to every greeting.
    suffix: str = "!"


class _Greeter(A2AAgent[_GreeterConfig, NoAuth]):
    # Sample agent shared by most tests: one working skill ("greet") and one
    # skill that always raises ("boom", implemented by ``_boom``).
    name = "greeter"
    description = "Says hi"
    config_model = _GreeterConfig
    auth_model = NoAuth

    @skill(description="Greet someone")
    async def greet(
        self, ctx: RunContext[NoAuth], who: str, loud: bool = False
    ) -> str:
        # Emit one progress event, then build the greeting from the configured
        # suffix; ``loud`` upper-cases the whole result.
        await ctx.emit_progress(f"greeting {who}")
        greeting = f"hello {who}{self.config.suffix}"
        if loud:
            return greeting.upper()
        return greeting

    @skill(name="boom", description="Fails on purpose")
    async def _boom(self, ctx: RunContext[NoAuth]) -> str:
        # Always fails so tests can observe how handler errors surface.
        raise ValueError("nope")


def _ctx() -> LocalRunContext[NoAuth]:
    """Return a new ``LocalRunContext`` whose auth is a ``NoAuth`` instance."""
    anonymous = NoAuth()
    return LocalRunContext(auth=anonymous)


# --- subclass / decorator validation ---


def test_subclass_without_name_rejected():
    """An agent subclass that sets no ``name`` is expected to raise ``TypeError``."""
    with pytest.raises(TypeError, match="name must be set"):

        class _Bad(A2AAgent):
            description = "missing name"


def test_skill_requires_async():
    """A synchronous ``@skill`` handler is expected to raise ``TypeError``."""
    with pytest.raises(TypeError, match="async"):

        class _Sync(A2AAgent):
            name = "sync"

            @skill(description="sync handler")
            def hi(self, ctx: RunContext[NoAuth]) -> str:  # type: ignore[misc]
                return "hi"


def test_skill_requires_run_context_param():
    """A skill handler with no ``RunContext`` parameter is expected to be rejected."""
    with pytest.raises(TypeError, match="RunContext"):

        class _NoCtx(A2AAgent):
            name = "noctx"

            @skill(description="missing ctx")
            async def hi(self, who: str) -> str:
                return who


def test_skill_rejects_var_args():
    """A skill handler that declares ``*args`` is expected to be rejected."""
    with pytest.raises(TypeError, match=r"\*args"):

        class _Va(A2AAgent):
            name = "va"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], *args: str) -> str:
                return ",".join(args)


def test_skill_rejects_var_kwargs():
    """A skill handler that declares ``**kwargs`` is expected to be rejected."""
    with pytest.raises(TypeError, match=r"\*\*kwargs"):

        class _Vk(A2AAgent):
            name = "vk"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], **kwargs: str) -> str:
                return str(kwargs)


def test_skill_rejects_untyped_param():
    """A skill parameter without a type annotation is expected to be rejected."""
    with pytest.raises(TypeError, match="missing a type annotation"):

        class _U(A2AAgent):
            name = "u"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], who) -> str:  # type: ignore[no-untyped-def]
                return who


def test_skill_rejects_reserved_name():
    """A skill parameter called ``context`` is expected to be rejected as reserved."""
    with pytest.raises(TypeError, match="reserved"):

        class _R(A2AAgent):
            name = "r"

            @skill()
            async def hi(self, ctx: RunContext[NoAuth], context: str) -> str:
                return context


def test_duplicate_skill_name_rejected():
    """Two handlers registered under the same skill name are expected to be rejected."""
    with pytest.raises(TypeError, match="duplicate skill name"):

        class _Dup(A2AAgent):
            name = "dup"

            @skill(name="x")
            async def a(self, ctx: RunContext[NoAuth]) -> str:
                return "a"

            @skill(name="x")
            async def b(self, ctx: RunContext[NoAuth]) -> str:
                return "b"


# --- card / metadata ---


def test_skills_collected_with_metadata():
    """The greeter exposes both skills by name, with the declared description."""
    agent = _Greeter()
    registered = agent.skills
    assert set(registered) == {"greet", "boom"}
    assert registered["greet"].description == "Greet someone"


def test_card_omits_ctx_param():
    """The ``greet`` input schema lists ``who`` as required and leaves out ``ctx``."""
    agent_card = _Greeter().card()
    assert isinstance(agent_card, AgentCard)

    greet_card = next(entry for entry in agent_card.skills if entry.name == "greet")
    schema = greet_card.input_schema
    assert "ctx" not in schema["properties"]
    assert schema["required"] == ["who"]
    assert schema["additionalProperties"] is False


def test_card_includes_deploy_metadata():
    """Deploy-related class attributes show up on the generated card."""

    class _Cfg(BaseModel):
        pass

    class _Meta(A2AAgent[_Cfg, NoAuth]):
        name = "meta"
        description = "metadata showcase"
        required_secrets = ("OPENAI_KEY",)
        required_env = ("REGION",)
        capabilities = {"streaming": True}
        input_modes = ("application/json", "text/plain")
        output_modes = ("application/json",)

        @skill()
        async def noop(self, ctx: RunContext[NoAuth]) -> str:
            return "ok"

    agent_card = _Meta().card()
    # The tuples declared on the class are compared against lists on the card.
    assert agent_card.required_secrets == ["OPENAI_KEY"]
    assert agent_card.required_env == ["REGION"]
    assert agent_card.capabilities == {"streaming": True}
    assert agent_card.input_modes == ["application/json", "text/plain"]
    # NOTE(review): output_modes is declared above but never asserted here.


# --- config hydration ---


def test_default_config_constructed_when_none():
    """With no config argument the agent ends up with the model's default suffix."""
    agent = _Greeter()
    assert agent.config.suffix == "!"


def test_explicit_config_used():
    """A config model instance passed to the constructor is the one in effect."""
    explicit = _GreeterConfig(suffix="?!")
    agent = _Greeter(explicit)
    assert agent.config.suffix == "?!"


def test_config_accepts_dict():
    """A plain dict is accepted as the config argument."""
    agent = _Greeter({"suffix": "?"})
    assert agent.config.suffix == "?"


def test_config_dict_validation_errors_propagate():
    """An invalid config dict makes construction raise."""
    # NOTE(review): presumably this is a pydantic validation error; the broad
    # ``Exception`` is kept because the concrete type is not visible here.
    with pytest.raises(Exception):
        _Greeter({"suffix": 123, "extra": True})  # type: ignore[arg-type]


# --- invocation ---
# NOTE(review): the async tests carry no explicit marker - presumably an async
# pytest plugin is configured to collect bare ``async def`` tests; confirm.


async def test_invoke_passes_ctx_and_returns_value():
    """``invoke`` returns the handler result and the context records a progress event."""
    agent = _Greeter()
    run_ctx = _ctx()
    reply = await agent.invoke("greet", run_ctx, who="bob")
    assert reply == "hello bob!"
    assert any(event.kind == "progress" for event in run_ctx.events)


async def test_invoke_validates_input_types():
    """Passing an ``int`` for the ``str`` parameter raises ``SkillInputError``."""
    agent = _Greeter()
    with pytest.raises(SkillInputError):
        await agent.invoke("greet", _ctx(), who=123)  # type: ignore[arg-type]


async def test_invoke_rejects_unknown_param():
    """An argument the skill does not declare raises ``SkillInputError``."""
    agent = _Greeter()
    with pytest.raises(SkillInputError, match="unknown parameter"):
        await agent.invoke("greet", _ctx(), who="bob", extra="nope")


async def test_invoke_missing_required_param():
    """Omitting the required ``who`` argument raises ``SkillInputError``."""
    agent = _Greeter()
    with pytest.raises(SkillInputError, match="missing required"):
        await agent.invoke("greet", _ctx())


async def test_invoke_unknown_skill_raises():
    """Invoking a skill name that is not registered raises ``SkillNotFound``."""
    agent = _Greeter()
    with pytest.raises(SkillNotFound):
        await agent.invoke("nope", _ctx())


async def test_invoke_handler_error_wrapped():
    """The ``ValueError`` from ``boom`` surfaces as ``SkillInvocationError``."""
    agent = _Greeter()
    with pytest.raises(SkillInvocationError, match="ValueError"):
        await agent.invoke("boom", _ctx())


class _Pair(BaseModel):
    # Structured return value used by the JSON-output test.
    a: int
    b: int


class _PairCfg(BaseModel):
    pass


class _PairAgent(A2AAgent[_PairCfg, NoAuth]):
    # Agent whose only skill returns a pydantic model instead of a plain string.
    name = "json-out"
    description = ""

    @skill()
    async def make(self, ctx: RunContext[NoAuth]) -> _Pair:
        return _Pair(a=1, b=2)


async def test_invoke_json_returns_serializable():
    """``invoke_json`` returns the model result as a plain dict."""
    agent = _PairAgent()
    payload = await agent.invoke_json("make", _ctx(), {})
    assert payload == {"a": 1, "b": 2}


# --- scopes ---


async def test_scope_enforcement_blocks_caller():
    """A caller lacking the skill's scope is refused; one holding it gets through."""

    class _Cfg(BaseModel):
        pass

    class _S(A2AAgent[_Cfg, JWTAuth]):
        name = "scoped-agent"
        description = ""
        auth_model = JWTAuth

        @skill(scopes=["admin"])
        async def secret_op(self, ctx: RunContext[JWTAuth]) -> str:
            return "ok"

    # Caller holds only "read", the skill demands "admin".
    reader = JWTAuth(sub="alice", scopes=["read"])
    denied_ctx = LocalRunContext(auth=reader)
    with pytest.raises(MissingScopes):
        await _S().invoke("secret_op", denied_ctx)

    # Same subject, now holding the required scope.
    admin = JWTAuth(sub="alice", scopes=["admin"])
    allowed_ctx = LocalRunContext(auth=admin)
    outcome = await _S().invoke("secret_op", allowed_ctx)
    assert outcome == "ok"


async def test_scope_enforcement_allows_no_scope_skill():
    """A skill declared without scopes can be invoked with ``NoAuth``."""
    agent = _Greeter()
    reply = await agent.invoke("greet", _ctx(), who="bob")
    assert reply == "hello bob!"


# --- streaming helpers ---


async def test_stream_helpers_emit_typed_events():
    """The emit helpers append events of the matching kind, in call order."""
    run_ctx = _ctx()
    for chunk in ("hello ", "world"):
        await run_ctx.emit_text_delta(chunk)
    await run_ctx.emit_error("uh oh", code="E001")

    emitted_kinds = [event.kind for event in run_ctx.events]
    assert emitted_kinds == ["text_delta", "text_delta", "error"]

    last_event = run_ctx.events[-1]
    assert last_event.payload == {"message": "uh oh", "code": "E001"}


# --- health ---


async def test_default_health_is_true():
    """An agent that does not override ``health`` reports ``True``."""
    healthy = await _Greeter().health()
    assert healthy is True


# --- local_invoke ---


async def test_local_invoke_default_no_auth():
    """``local_invoke`` works without the caller supplying a context or auth."""
    agent = _Greeter()
    reply = await agent.local_invoke("greet", who="bob")
    assert reply == "hello bob!"
async def test_local_invoke_with_explicit_auth():
    """``local_invoke`` hands the supplied auth object to the skill via ``ctx.auth``."""

    class _Cfg(BaseModel):
        pass

    class _A(A2AAgent[_Cfg, JWTAuth]):
        name = "auth-test"
        description = ""
        auth_model = JWTAuth

        @skill()
        async def whoami(self, ctx: RunContext[JWTAuth]) -> str:
            return f"{ctx.auth.sub}@{ctx.auth.org_id}"

    caller = JWTAuth(sub="alice", org_id="acme")
    identity = await _A().local_invoke("whoami", auth=caller)
    assert identity == "alice@acme"


# --- artifacts / inheritance ---


async def test_artifact_round_trip():
    """An artifact written by a skill is stored on the context and announced."""

    class _Cfg(BaseModel):
        pass

    class _A(A2AAgent[_Cfg, NoAuth]):
        name = "art"
        description = ""

        @skill()
        async def write(self, ctx: RunContext[NoAuth], body: str) -> str:
            # Store the encoded body, announce the artifact, return its URI.
            encoded = body.encode()
            artifact_ref = await ctx.write_artifact("note.txt", encoded, "text/plain")
            await ctx.emit_artifact(artifact_ref)
            return artifact_ref.uri

    run_ctx = _ctx()
    artifact_uri = await _A().invoke("write", run_ctx, body="hello")

    assert artifact_uri.startswith("memory://")
    assert run_ctx.artifacts["note.txt"] == b"hello"
    assert any(event.kind == "artifact" for event in run_ctx.events)


def test_skill_inheritance_preserves_parent_skills():
    """A subclass keeps the parent's skills and adds the one it declares."""

    class _Loud(_Greeter):
        name = "loud-greeter"

        @skill(description="Shout")
        async def shout(self, ctx: RunContext[NoAuth], what: str) -> str:
            return what.upper()

    inherited_and_own = _Loud().skills
    assert set(inherited_and_own) == {"greet", "boom", "shout"}


# --- runtime metadata ---


class _SessionState(BaseModel):
    # State model used by the stateful-agent tests below.
    history: list[str] = []


def test_default_runtime_is_ephemeral_no_state_microsandbox():
    """With no overrides, ``runtime()`` reports the expected default values."""
    defaults = _Greeter.runtime()
    assert defaults.lifecycle is Lifecycle.EPHEMERAL
    assert defaults.state is State.NONE
    assert defaults.sandbox is Sandbox.MICROSANDBOX  # safe-by-default
    assert defaults.concurrency == 1


def test_state_requires_state_model():
    """Declaring session state without a ``state_model`` is expected to be rejected."""
    with pytest.raises(TypeError, match="state_model"):

        class _Bad(A2AAgent):
            name = "bad"
            description = ""
            state = State.SESSION


def test_ephemeral_lifecycle_incompatible_with_session_state():
    """Session state combined with an ephemeral lifecycle is expected to be rejected."""
    with pytest.raises(TypeError, match="ephemeral.*session"):

        class _Bad(A2AAgent):
            name = "bad"
            description = ""
            lifecycle = Lifecycle.EPHEMERAL
            state = State.SESSION
            state_model = _SessionState


def test_runtime_metadata_propagates_to_card():
    """Runtime settings, the state schema and skill policy all reach the card."""

    class _ChatCfg(BaseModel):
        pass

    class _Chat(A2AAgent[_ChatCfg, JWTAuth]):
        name = "chat"
        description = "stateful chat agent"
        auth_model = JWTAuth
        lifecycle = Lifecycle.SESSION
        state = State.SESSION
        state_model = _SessionState
        resources = Resources(cpu="2", memory="4Gi", gpu=0, max_runtime_seconds=1800)
        concurrency = 4
        egress = EgressPolicy(
            allow_hosts=("api.openai.com",),
            allow_internal_services=("litellm.llm.svc.cluster.local",),
        )
        tools_used = ("litellm", "minio")

        @skill(timeout_seconds=30, idempotent=True, max_retries=2, cost_class="cheap")
        async def reply(self, ctx: RunContext[JWTAuth], message: str) -> str:
            return f"echo {message}"

    agent_card = _Chat().card()

    # Runtime section mirrors the class-level declarations.
    runtime = agent_card.runtime
    assert runtime.lifecycle is Lifecycle.SESSION
    assert runtime.state is State.SESSION
    assert runtime.sandbox is Sandbox.MICROSANDBOX
    assert runtime.resources.cpu == "2"
    assert runtime.concurrency == 4
    assert runtime.egress.allow_hosts == ("api.openai.com",)
    assert runtime.tools_used == ("litellm", "minio")

    # The state schema carries the ``history`` field of ``_SessionState``.
    state_schema = agent_card.state_schema
    assert state_schema is not None
    assert "history" in state_schema["properties"]

    # Per-skill policy comes from the ``@skill`` keyword arguments.
    policy = agent_card.skills[0].policy
    assert policy.timeout_seconds == 30
    assert policy.idempotent is True
    assert policy.max_retries == 2
    assert policy.cost_class == "cheap"


def test_skill_metadata_propagates_to_card():
    """Scopes, the stream flag and tags given to ``@skill`` appear on the card."""

    class _ScopedConfig(BaseModel):
        pass

    class _Scoped(A2AAgent[_ScopedConfig, APIKeyAuth]):
        name = "scoped"
        description = "scope test"
        config_model = _ScopedConfig
        auth_model = APIKeyAuth

        @skill(scopes=["a:read", "a:write"], stream=True, tags=["x"])
        async def do(self, ctx: RunContext[APIKeyAuth]) -> str:
            return "ok"

    agent_card = _Scoped().card()
    only_skill = agent_card.skills[0]
    assert only_skill.scopes == ["a:read", "a:write"]
    assert only_skill.stream is True
    assert only_skill.tags == ["x"]