98 lines
3.4 KiB
Python
98 lines
3.4 KiB
Python
"""Graph agent: turns prompts into PNG charts inside a microsandbox VM.

Receives an a2a grant from the caller, derives the user's MinIO bucket
from the grant, asks the cluster sandbox runtime to spin up a microVM with
that bucket FUSE/bridge-mounted at /workspace, runs matplotlib inside, and
writes the PNG back into the bucket. Currently the prompt only supplies the
chart title; the plotted data is a fixed sample series.
"""
|
|
from __future__ import annotations
|
|
|
|
import os
import shlex

import httpx
from pydantic import BaseModel

from a2a_pack import A2AAgent, NoAuth, RunContext, skill
|
|
|
|
# Base URL of the cluster sandbox runtime, overridable via $SANDBOX_URL.
# An empty (but set) variable falls back to the in-cluster default instead of
# producing a relative URL, and any trailing slash is stripped so that
# f"{SANDBOX_URL}/v1/..." never yields a double slash in the request path.
SANDBOX_URL = (
    os.environ.get("SANDBOX_URL") or "http://sandbox.sandbox.svc.cluster.local:8000"
).rstrip("/")
|
|
|
|
|
|
class GraphConfig(BaseModel):
    """Deployment-level settings for :class:`GraphAgent`."""

    # Container image requested from the sandbox runtime for each run.
    # matplotlib is pip-installed into it at run time by the chart script.
    image: str = "python:3.11-slim"
|
|
|
|
|
|
class GraphAgent(A2AAgent[GraphConfig, NoAuth]):
    """A2A agent that renders a PNG bar chart inside a sandbox microVM.

    The caller's workspace grant names the bucket that the sandbox runtime is
    asked to mount at ``/workspace``; the chart is written to
    ``/workspace/outputs/chart.png`` (``outputs/chart.png`` in the bucket).
    """

    name = "graph-agent"
    description = "Generate matplotlib charts in an isolated microVM, write PNG to /workspace/outputs/"
    version = "0.1.0"

    config_model = GraphConfig
    auth_model = NoAuth
    tools_used = ("microsandbox", "matplotlib")

    @skill(
        description="Render a chart and save it as a PNG in the caller's workspace.",
        tags=["visualization", "chart", "spreadsheet"],
    )
    async def generate_chart(
        self, ctx: RunContext[NoAuth], prompt: str
    ) -> dict:
        """Render a bar chart titled with *prompt* into the caller's bucket.

        Only the title comes from *prompt*; the plotted values are a fixed
        sample series (Q1-Q4).

        Args:
            ctx: Run context; ``ctx.workspace.bucket`` names the bucket the
                caller granted us.
            prompt: Text used as the chart title.

        Returns:
            On a completed run: ``prompt``, ``bucket``, ``chart_path``
            (relative to the bucket), ``stdout`` and ``exit_code``; if the
            script exited non-zero the same dict also carries ``error``.
            On a missing grant, transport failure, HTTP status >= 400 or a
            non-JSON-object response: a dict with ``error`` (and ``detail``
            where available).
        """
        # The grant we received gives us a workspace bound to the caller's
        # bucket; only that bucket is handed to the sandbox runtime.
        workspace = getattr(ctx, "workspace", None)
        bucket = getattr(workspace, "bucket", None)
        if not bucket:
            return {"error": "no workspace grant; refusing to run"}

        await ctx.emit_progress(f"rendering '{prompt}' for bucket {bucket}")

        # The title travels via an env var so the heredoc'd Python script
        # (quoted 'PY' delimiter => no shell expansion inside it) never has
        # to embed it. shlex.quote makes the assignment safe for any prompt:
        # repr() is Python quoting, not shell quoting, and would let quotes,
        # `$(...)` or backticks in the prompt be interpreted by the shell.
        script = (
            f"export CHART_TITLE={shlex.quote(prompt)}\n"
            "set -e\n"
            "pip install -q --no-cache-dir matplotlib >/dev/null\n"
            'python - <<\'PY\'\n'
            "import os, matplotlib\n"
            "matplotlib.use('Agg')\n"
            "import matplotlib.pyplot as plt\n"
            "os.makedirs('/workspace/outputs', exist_ok=True)\n"
            "title = os.environ.get('CHART_TITLE', 'chart')\n"
            "data = {'Q1': 12, 'Q2': 19, 'Q3': 15, 'Q4': 27}\n"
            "fig, ax = plt.subplots(figsize=(8, 5))\n"
            "ax.bar(list(data.keys()), list(data.values()), color='#4f46e5')\n"
            "ax.set_title(title)\n"
            "ax.set_ylabel('value')\n"
            "fig.tight_layout()\n"
            "fig.savefig('/workspace/outputs/chart.png', dpi=120)\n"
            "print('wrote /workspace/outputs/chart.png')\n"
            "PY\n"
        )

        # Client timeout (180 s) is deliberately longer than the in-VM
        # timeout (150 s) so the sandbox gets to report its own timeout.
        try:
            async with httpx.AsyncClient(timeout=180.0) as c:
                r = await c.post(
                    f"{SANDBOX_URL}/v1/run_shell",
                    json={
                        "bucket": bucket,
                        "script": script,
                        "image": self.config.image,
                        "memory_mib": 1024,
                        "timeout_seconds": 150,
                    },
                )
        except httpx.HTTPError as exc:
            # Connection/timeout failures: report them like every other
            # failure path instead of letting them escape as exceptions.
            return {
                "error": f"sandbox request failed: {type(exc).__name__}",
                "detail": str(exc)[:1000],
            }

        if r.status_code >= 400:
            return {
                "error": f"sandbox {r.status_code}",
                "detail": r.text[:1000],
            }

        try:
            out = r.json()
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            out = None
        if not isinstance(out, dict):
            return {
                "error": "sandbox returned an unexpected response",
                "detail": r.text[:1000],
            }

        exit_code = out.get("exit_code", -1)
        result = {
            "prompt": prompt,
            "bucket": bucket,
            "chart_path": "outputs/chart.png",
            "stdout": out.get("stdout", ""),
            "exit_code": exit_code,
        }
        if exit_code != 0:
            # The script failed (set -e), so the PNG was most likely not
            # written; keep the original keys but don't report success.
            await ctx.emit_progress(f"chart script failed (exit {exit_code})")
            result["error"] = f"chart script exited with {exit_code}"
            return result

        await ctx.emit_progress("chart rendered")
        return result
|