# WalkXR-AI / app / agent_tester.py
# Streamlit tester UI for the WalkXR conversational agent.
import streamlit as st
import requests

# Local FastAPI backend endpoint the tester posts each chat turn to.
API_URL = "http://127.0.0.1:8000/chat"

st.set_page_config(page_title="WalkXR Agent Tester", page_icon="🧠")

st.title("WalkXR Agent Tester")
st.write("Test your conversational agent below. Each turn shows debug info including context and the generated prompt.")

# --- Step 1: Initialize session state ---
# `messages` is the rendered chat transcript; `history` is the list of
# (user_msg, agent_msg) tuples sent to the backend on each turn.
for _key in ("messages", "history"):
    if _key not in st.session_state:
        st.session_state[_key] = []

# --- Step 2: Display past chat ---
# Replay the stored transcript. Assistant turns may carry a "debug" payload
# with the generated prompt, retrieved chunks, and a LangSmith trace.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
        if entry["role"] == "assistant" and "debug" in entry:
            dbg = entry["debug"]
            with st.expander("🛠 Debug Info"):
                prompt_text = dbg.get("prompt")
                if prompt_text:
                    st.markdown("**Prompt**")
                    st.code(prompt_text)
                chunks = dbg.get("source_chunks")
                if chunks:
                    st.markdown("**Retrieved Chunks**")
                    for chunk in chunks:
                        st.markdown(f"- {chunk}")
                trace = dbg.get("langsmith_trace")
                if trace:
                    st.markdown("**LangSmith Trace**")
                    st.json(trace)

# --- Step 3: Chat input ---
if prompt := st.chat_input("Type a message..."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    # --- Step 4: Build the request payload ---
    # `history` holds only COMPLETED (user, agent) turns. It is snapshotted
    # before this turn is recorded, so the backend does not receive the
    # current message twice (it already arrives in `message`), and no
    # empty-response placeholder is ever sent.
    payload = {
        "user_id": "test-session",
        "stage": "demo",
        "message": prompt,
        "history": list(st.session_state.history)
    }

    # --- Step 5: Call FastAPI backend ---
    try:
        # Explicit timeout so a hung backend cannot freeze the Streamlit
        # script run indefinitely (requests has no default timeout).
        res = requests.post(API_URL, json=payload, timeout=30)

        if res.status_code == 200:
            data = res.json()
            response_text = data.get("response_text", "[No reply]")
            source_chunks = data.get("source_chunks", [])
            debug = data.get("debug", {})

            st.session_state.messages.append({
                "role": "assistant",
                "content": response_text,
                "debug": {
                    "prompt": debug.get("prompt", ""),
                    "source_chunks": source_chunks,
                    "langsmith_trace": debug.get("langsmith_trace", {})
                }
            })

            # Record the completed turn only on success, so a failed request
            # never leaves a dangling (prompt, "") placeholder in history.
            st.session_state.history.append((prompt, response_text))

            with st.chat_message("assistant"):
                st.markdown(response_text)
                with st.expander("🛠 Debug Info"):
                    if debug.get("prompt"):
                        st.markdown("**Prompt**")
                        st.code(debug["prompt"])
                    if source_chunks:
                        st.markdown("**Retrieved Chunks**")
                        for chunk in source_chunks:
                            st.markdown(f"- {chunk}")
                    if debug.get("langsmith_trace"):
                        st.markdown("**LangSmith Trace**")
                        st.json(debug["langsmith_trace"])

        else:
            st.error(f"Error {res.status_code}: {res.text}")

    # Boundary handler: surface connection/timeout problems in the UI
    # instead of crashing the script run.
    except Exception as e:
        st.error(f"Failed to connect to backend: {e}")

with st.sidebar:
    st.markdown("## Developer Debug Panel")
    show_debug = st.checkbox("Show full debug trace", value=False)

    if show_debug:
        # One section per assistant turn; messages alternate user/assistant,
        # so integer-dividing the index by 2 yields the turn number.
        for idx, entry in enumerate(st.session_state.messages):
            if entry["role"] != "assistant" or "debug" not in entry:
                continue
            st.markdown(f"### Turn {idx // 2 + 1}")
            dbg = entry["debug"]
            if dbg.get("prompt"):
                st.markdown("**Prompt**")
                st.code(dbg["prompt"])
            if dbg.get("source_chunks"):
                st.markdown("**Source Chunks**")
                for chunk in dbg["source_chunks"]:
                    st.markdown(f"- {chunk}")
            if dbg.get("langsmith_trace"):
                st.markdown("**LangSmith Trace**")
                st.json(dbg["langsmith_trace"])