Chat Engine — FastAPI + LangGraph
Level: Intermediate · Pre-reading: 06 · LangGraph Agent, 09 · HITL Design
This document covers the FastAPI-based chat engine that serves as the user-facing interface to the LangGraph agent. It manages WebSocket sessions, streams agent responses, and handles human-in-the-loop resume calls.
Chat Engine Architecture
sequenceDiagram
participant U as User (Browser / CLI)
participant WS as FastAPI WebSocket /chat
participant LG as LangGraph Agent
participant PG as PostgreSQL Checkpointer
participant MCP as MCP Servers
U->>WS: Connect + send {"ticket_key": "TASK-101"}
WS->>LG: graph.stream(initial_state, config={thread_id})
loop Agent nodes execute
LG-->>WS: stream event (message update)
WS-->>U: {"type": "message", "content": "..."}
end
LG->>PG: save state at interrupt()
LG-->>WS: interrupt event
WS-->>U: {"type": "awaiting_approval", "diff_summary": "..."}
U->>WS: {"type": "resume", "response": "approve"}
WS->>LG: graph.stream(Command(resume="approve"), config)
LG->>PG: load state
loop Remaining nodes
LG->>MCP: create_branch · commit_file · create_pr
LG-->>WS: stream events
WS-->>U: {"type": "message", "content": "PR created: ..."}
end
WS-->>U: {"type": "done", "pr_url": "https://..."}
FastAPI Application
# chat_engine/main.py
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio
import json
import uuid
import logging
from langgraph.types import Command
from agent.graph import build_agent_graph
from agent.state import AgentState
# ASGI application object; the route decorators in this module attach to it.
app = FastAPI(title="TaskMaster Chat Engine")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# CORS is configured fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Build agent graph once at startup
agent_graph = build_agent_graph()

# In-memory session store (use Redis in production).
# Maps thread_id -> {"ticket_key", "websocket", "status"} for live connections.
active_sessions: dict[str, dict] = {}
# ─── WebSocket Chat Endpoint ───────────────────────────────────────────────────
@app.websocket("/chat")
async def chat_websocket(websocket: WebSocket):
    """Main conversational endpoint. Handles full agent lifecycle over WebSocket.

    Protocol, as implemented here:
      1. The client's first text frame must be JSON with a ``ticket_key``.
      2. Graph events are forwarded to the client via ``_handle_stream_event``.
      3. When the event stream ends but the graph still has pending nodes
         (it paused for approval), the socket is kept open and the session
         stays registered in ``active_sessions`` with status
         ``"awaiting_approval"``. The client may then resume in-band with
         ``{"type": "resume", "response": "..."}``; keeping the session
         registered is also what lets ``POST /threads/{thread_id}/resume``
         find this thread.
      4. ``{"type": "done"}`` is sent only once the graph has no pending nodes.

    The session entry is always removed when the handler exits.
    """
    await websocket.accept()
    thread_id = str(uuid.uuid4())
    config = {"configurable": {"thread_id": thread_id}}

    try:
        # Wait for the initial message with ticket key. Anything that is not a
        # JSON object with a string ticket_key is treated as "no key given"
        # instead of surfacing a parser error to the client.
        raw = await websocket.receive_text()
        try:
            initial_msg = json.loads(raw)
        except json.JSONDecodeError:
            initial_msg = None
        raw_key = initial_msg.get("ticket_key") if isinstance(initial_msg, dict) else None
        ticket_key = raw_key.strip().upper() if isinstance(raw_key, str) else ""

        if not ticket_key:
            await websocket.send_json({"type": "error", "message": "Please provide a JIRA ticket key."})
            return

        await websocket.send_json({
            "type": "system",
            "message": f"🤖 Starting agent for **{ticket_key}**...",
            "thread_id": thread_id
        })

        # Store session
        session = {
            "ticket_key": ticket_key,
            "websocket": websocket,
            "status": "running"
        }
        active_sessions[thread_id] = session

        # First pass runs the graph from its initial state; every further pass
        # feeds it the resume command received from the client.
        graph_input = AgentState(
            ticket_key=ticket_key,
            thread_id=thread_id,
            messages=[],
            iteration_count=0
        )

        while True:
            session["status"] = "running"
            async for event in agent_graph.astream_events(graph_input, config, version="v2"):
                await _handle_stream_event(websocket, thread_id, event)

            # The stream ending does not mean the graph finished: it also ends
            # when the graph pauses. Pending nodes in the snapshot mean "paused".
            snapshot = await agent_graph.aget_state(config)
            if not snapshot.next:
                break

            session["status"] = "awaiting_approval"
            graph_input = await _await_resume_command(websocket, session)

        # Agent completed: no nodes left to run.
        session["status"] = "done"
        await websocket.send_json({
            "type": "done",
            "pr_url": snapshot.values.get("pr_url"),
            "message": "✅ All done! Check your GitHub repo and JIRA ticket."
        })

    except WebSocketDisconnect:
        logger.info("Client disconnected: thread_id=%s", thread_id)
    except Exception as e:
        logger.exception("Agent error for thread %s: %s", thread_id, e)
        # The socket may already be unusable (that can be the very reason we
        # are here); never let the error report raise out of the handler.
        try:
            await websocket.send_json({"type": "error", "message": f"Agent error: {e}"})
        except (WebSocketDisconnect, RuntimeError):
            logger.debug("Could not deliver error to client: thread_id=%s", thread_id)
    finally:
        active_sessions.pop(thread_id, None)


async def _await_resume_command(websocket: WebSocket, session: dict) -> Command:
    """Wait for an in-band resume message and turn it into a graph command.

    Frames that are not ``{"type": "resume", "response": <str>}`` are answered
    with a non-fatal ``system`` notice and skipped. A resume frame is also
    skipped when the session is no longer ``"awaiting_approval"`` (for example
    because the HTTP resume endpoint already claimed the thread).

    Raises:
        WebSocketDisconnect: propagated from ``receive_text`` when the client
            goes away while we wait.
    """
    while True:
        try:
            reply = json.loads(await websocket.receive_text())
        except json.JSONDecodeError:
            reply = None

        is_resume = (
            isinstance(reply, dict)
            and reply.get("type") == "resume"
            and isinstance(reply.get("response"), str)
        )
        # No await between this check and the caller flipping the status back
        # to "running", so a concurrent HTTP resume cannot claim it in between.
        if is_resume and session["status"] == "awaiting_approval":
            return Command(resume=reply["response"])

        await websocket.send_json({
            "type": "system",
            "message": 'Message ignored: send {"type": "resume", "response": "..."} while approval is pending.'
        })
async def _handle_stream_event(websocket: WebSocket, thread_id: str, event: dict):
    """Translate LangGraph stream events into WebSocket messages.

    Handles three event kinds and ignores everything else:
      * ``on_chain_start``  -> ``{"type": "node_start", "node": <name>}``
      * ``on_chain_end``    -> ``{"type": "message", ...}`` for the latest entry
        of ``output["messages"]``, when the output carries one
      * ``on_custom_event`` named ``"interrupt"`` -> ``{"type": "awaiting_approval", ...}``
        and the session (if still registered) is marked ``"awaiting_approval"``

    Args:
        websocket: connection the translated messages are written to.
        thread_id: key of the session in ``active_sessions``.
        event: one event dict as yielded by ``astream_events``.
    """
    kind = event.get("event")

    if kind == "on_chain_start":
        node = event["name"]
        await websocket.send_json({"type": "node_start", "node": node})

    elif kind == "on_chain_end":
        # Check for new messages in state. The output of a finished chain is
        # not necessarily a dict (it can be None or a plain value), so only
        # dict outputs are inspected.
        output = (event.get("data") or {}).get("output")
        messages = output.get("messages") if isinstance(output, dict) else None
        if isinstance(messages, (list, tuple)) and messages:
            latest = messages[-1]  # send the latest message only
            await websocket.send_json({
                "type": "message",
                "content": latest.content if hasattr(latest, "content") else str(latest)
            })

    elif kind == "on_custom_event" and event.get("name") == "interrupt":
        # HITL interrupt — surface the diff to the user
        data = event.get("data")
        if not isinstance(data, dict):
            data = {}
        # The session can already be gone (the WebSocket handler removes it on
        # exit) while another caller is still streaming; don't crash on that.
        session = active_sessions.get(thread_id)
        if session is not None:
            session["status"] = "awaiting_approval"
        await websocket.send_json({
            "type": "awaiting_approval",
            "prompt": data.get("prompt"),
            "diff_summary": data.get("diff_summary"),
            "ticket_key": data.get("ticket_key")
        })
# ─── Resume Endpoint (after HITL approval) ────────────────────────────────────
class ResumeRequest(BaseModel):
    # Request body for POST /threads/{thread_id}/resume.
    # The reviewer's decision: "approve", "reject", or feedback text
    response: str
@app.post("/threads/{thread_id}/resume")
async def resume_thread(thread_id: str, body: ResumeRequest):
    """Resume a paused agent thread with the user's approval decision.

    The remaining graph events are streamed over the WebSocket stored in the
    thread's session. When streaming ends, the session status is set to
    ``"awaiting_approval"`` if the graph paused again, or to ``"done"`` (and a
    ``done`` message is sent to the client) if no nodes are left to run.

    Raises:
        HTTPException(404): no live session exists for ``thread_id``.
        HTTPException(400): the session is not in ``"awaiting_approval"``.
        HTTPException(500): resuming failed; the session status is set to
            ``"error"`` and the client is notified on a best-effort basis.
    """
    session = active_sessions.get(thread_id)
    if not session:
        raise HTTPException(404, detail=f"Thread {thread_id} not found or already completed")

    if session["status"] != "awaiting_approval":
        raise HTTPException(400, detail=f"Thread {thread_id} is not awaiting approval (status: {session['status']})")

    config = {"configurable": {"thread_id": thread_id}}
    websocket: WebSocket = session["websocket"]

    # Claim the thread before the first await so that a second, concurrent
    # resume request fails the status check above instead of running twice.
    session["status"] = "running"

    try:
        await websocket.send_json({
            "type": "system",
            "message": f"▶️ Resuming with: **{body.response}**"
        })

        # Resume the graph from the interrupt checkpoint
        async for event in agent_graph.astream_events(
            Command(resume=body.response), config, version="v2"
        ):
            await _handle_stream_event(websocket, thread_id, event)

        # Pending nodes after the stream ends mean the graph paused again.
        snapshot = await agent_graph.aget_state(config)
        if snapshot.next:
            session["status"] = "awaiting_approval"
        else:
            session["status"] = "done"
            await websocket.send_json({
                "type": "done",
                "pr_url": snapshot.values.get("pr_url"),
                "message": "✅ All done! Check your GitHub repo and JIRA ticket."
            })
    except Exception as e:
        # Don't leave the session stuck in "running" after a failure.
        session["status"] = "error"
        logger.exception("Resume failed for thread %s: %s", thread_id, e)
        try:
            await websocket.send_json({"type": "error", "message": f"Agent error: {e}"})
        except (WebSocketDisconnect, RuntimeError):
            logger.debug("Could not deliver error to client: thread_id=%s", thread_id)
        raise HTTPException(500, detail=f"Resume failed for thread {thread_id}") from e

    return {"status": "resumed", "thread_id": thread_id}
# ─── Health and Status Endpoints ──────────────────────────────────────────────
@app.get("/health")
def health():
    # Liveness probe: the payload is constant and touches no other state.
    payload = {"status": "ok"}
    return payload
@app.get("/threads/{thread_id}/status")
def thread_status(thread_id: str):
    # Combines two sources: the graph's stored state for this thread and the
    # in-memory session entry (absent once the WebSocket handler has exited).
    snapshot = agent_graph.get_state({"configurable": {"thread_id": thread_id}})
    live_session = active_sessions.get(thread_id, {})

    # Start from the "nothing known" answer and fill in graph details only
    # when a state snapshot is available.
    report = {
        "thread_id": thread_id,
        "status": live_session.get("status", "unknown"),
        "current_node": None,
        "ticket_key": None,
        "pr_url": None
    }
    if snapshot:
        report["current_node"] = snapshot.next
        report["ticket_key"] = snapshot.values.get("ticket_key")
        report["pr_url"] = snapshot.values.get("pr_url")
    return report
Prompt Templates
# agent/prompts.py
# System prompt for the agent: states its role (resolve JIRA tickets end to end
# and open a Pull Request) and five hard rules, including waiting for human
# approval before the PR is created.
# NOTE(review): the code that passes this prompt to the model is not part of
# this document — confirm where it is consumed.
SYSTEM_PROMPT = """You are the TaskMaster AI development agent. You help software teams
automatically resolve JIRA tickets by reading the ticket, understanding the codebase,
generating the correct code changes, and creating a Pull Request.
Rules you MUST follow:
1. Only modify files in modules that are relevant to the ticket
2. Never delete existing functionality — only add or fix
3. Always write tests for every change you make
4. Never push directly to main — always create a feature branch
5. Always wait for human approval before creating the PR
When you identify a root cause, explain it clearly in plain English before proposing a fix.
When implementing a story, enumerate each acceptance criterion and confirm it is covered.
"""
# Prompt used to decide which TaskMaster modules a ticket touches. It lists the
# three modules (core, api, e2e) and the decision rules mapping a kind of
# change to the smallest set of modules.
# NOTE(review): the graph node that consumes this prompt is not shown in this
# document — confirm against the agent code.
MODULE_IDENTIFICATION_PROMPT = """Given a JIRA ticket, identify which modules of the
TaskMaster project need to change.
Module descriptions:
- taskmaster-core: Domain entities (Task.java), JPA repository, TaskService
- taskmaster-api: REST controllers (TaskController), request/response DTOs
- taskmaster-e2e: Playwright TypeScript E2E tests (runs against live API)
Decision rules:
- Bug in service/domain logic → taskmaster-core only
- New field on entity + exposed via API → taskmaster-core + taskmaster-api + taskmaster-e2e
- Bug in controller/DTO → taskmaster-api only
- Test-only change → taskmaster-e2e only
Always apply minimum scope. If unsure between one module and two, ask for clarification.
"""
Simple CLI Client
For local testing without a browser UI:
#!/usr/bin/env python3
# chat_client.py — terminal chat interface for the TaskMaster agent
import asyncio
import json
import websockets
import sys
import requests
# WebSocket endpoint of the chat engine (the server's /chat route).
CHAT_ENGINE_URL = "ws://localhost:8080/chat"
# HTTP resume endpoint; {thread_id} is filled in with str.format() per request.
RESUME_URL = "http://localhost:8080/threads/{thread_id}/resume"
async def chat(ticket_key: str):
    """Run one interactive agent session for ``ticket_key`` in the terminal.

    Connects to the chat engine, sends the ticket key, then prints every
    message the server streams back. When the server asks for approval, the
    user's decision is read from stdin and posted to the resume endpoint.
    Returns when a ``done`` or ``error`` message arrives, when the connection
    is closed, or when the resume request cannot be made.

    Blocking calls (``input`` and ``requests.post``) run in a worker thread so
    the event loop stays free to service the WebSocket while the user is
    reading the diff or the server is working.
    """
    async with websockets.connect(CHAT_ENGINE_URL) as ws:
        # Send the ticket key
        await ws.send(json.dumps({"ticket_key": ticket_key}))
        thread_id = None

        print(f"\n🤖 Agent started for {ticket_key}\n{'─'*60}\n")

        while True:
            try:
                raw = await ws.recv()
            except websockets.ConnectionClosed:
                print("\n❌ Error: connection closed by the chat engine")
                break
            msg = json.loads(raw)
            msg_type = msg.get("type")

            if msg_type == "system":
                # Only the first system message carries the thread_id; keep
                # the known value for later ones.
                thread_id = msg.get("thread_id", thread_id)
                print(f"🔧 {msg['message']}")

            elif msg_type == "node_start":
                print(f"\n▶ Running: {msg['node']}", end="", flush=True)

            elif msg_type == "message":
                print(f"\n{msg['content']}")

            elif msg_type == "awaiting_approval":
                print(f"\n{'═'*60}")
                # The server sends both keys and either may be null, so fall
                # back on a missing *value*, not just a missing key.
                print(msg.get("diff_summary") or msg.get("prompt"))
                print(f"{'═'*60}")

                if thread_id is None:
                    print("\n❌ Error: no thread_id received, cannot resume")
                    break

                answer = await asyncio.to_thread(
                    input, "\n✍️ Your decision (approve / reject / feedback): "
                )
                response = answer.strip()

                # Resume via HTTP POST (simulates a separate client action)
                try:
                    result = await asyncio.to_thread(
                        requests.post,
                        RESUME_URL.format(thread_id=thread_id),
                        json={"response": response},
                    )
                except requests.RequestException as exc:
                    print(f"\n❌ Error: resume request failed: {exc}")
                    break
                print(f"▶️ Resumed (status: {result.status_code})")

            elif msg_type == "done":
                pr_url = msg.get("pr_url") or "N/A"
                print(f"\n{'═'*60}")
                print(f"✅ DONE! PR: {pr_url}")
                print(f"{'═'*60}\n")
                break

            elif msg_type == "error":
                print(f"\n❌ Error: {msg['message']}")
                break
if __name__ == "__main__":
    # Take the ticket key from the command line, or ask for it. Normalise it
    # the same way on both paths so the banner printed by chat() matches the
    # upper-cased key the server works with.
    raw_ticket = sys.argv[1] if len(sys.argv) > 1 else input("Enter JIRA ticket key: ")
    asyncio.run(chat(raw_ticket.strip().upper()))
Usage:
# Terminal 1: start the chat engine
uvicorn main:app --host 0.0.0.0 --port 8080 --reload
# Terminal 2: run the CLI client
python3 chat_client.py TASK-101
Dockerfile for the Chat Engine
FROM python:3.11-slim
WORKDIR /app
# Copy and install dependencies before the application code so this layer is
# reused when only source files change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# Unbuffered stdout/stderr so log lines appear immediately in container logs.
ENV PYTHONUNBUFFERED=1
EXPOSE 8080
# Run exactly one worker. Sessions (including the live WebSocket object) are
# kept in a per-process dict, so with several workers a resume request can be
# routed to a process that does not hold the session and gets a 404.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080", "--workers", "1"]
# requirements.txt
fastapi==0.111.0
uvicorn[standard]==0.30.1
websockets==12.0
# NOTE(review): main.py imports `Command` from `langgraph.types` — confirm that
# this pinned release provides it (TODO verify pin).
langgraph==0.2.0
langchain-aws==0.1.6
boto3==1.34.0
# NOTE(review): no PostgreSQL checkpointer package is pinned in this list even
# though the text relies on PostgresSaver — confirm which package supplies it.
psycopg2-binary==2.9.9
requests==2.31.0
pydantic==2.7.0
Why WebSocket instead of a simple HTTP POST endpoint?
The agent runs for 3–15 minutes per ticket with multiple checkpoints and messages. WebSocket enables real-time streaming of each node's progress, including the HITL gate prompt, without the client having to poll. HTTP long-polling works but is less clean.
How is thread state persisted so the HITL resume works if the server restarts?
LangGraph's PostgresSaver writes the full AgentState to the checkpoints table after every node. If the ECS task restarts, the state is reloaded from PostgreSQL on the next resume call using the thread_id. The only in-memory state is the active WebSocket connection.
What happens if the user closes their browser tab while the agent is running?
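The WebSocket disconnects and active_sessions[thread_id] is cleared. The agent state is safely checkpointed in PostgreSQL, so GET /threads/{thread_id}/status still reports where the agent paused. Note, however, that POST /threads/{thread_id}/resume as implemented above only accepts threads that still have an entry in active_sessions (it streams the remaining events over that session's WebSocket), so it returns 404 after a disconnect. Letting the user reconnect later and continue requires a way to attach a new WebSocket to the existing thread_id before calling /threads/{thread_id}/resume.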