"""openclaw-proxy — async WebSocket gateway between XiaoZhi ESP32 and OpenClaw agent."""

import asyncio
import json
import logging
import os
import tempfile
import uuid
from enum import Enum
from shutil import which
from typing import Literal

import edge_tts
import whisper
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ValidationError

# ── Logging ──────────────────────────────────────────────────────────

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("openclaw-proxy")

# ── Configuration ────────────────────────────────────────────────────

# Absolute path to the OpenClaw CLI if found; otherwise the bare name so a
# PATH lookup at exec time can still succeed.
OPENCLAW_CLI = which("openclaw") or "openclaw"
# Agent ID passed to `openclaw agent --agent …`.
OPENCLAW_AGENT_ID = "main"
OPENCLAW_TIMEOUT = 120  # seconds, forwarded to the CLI via --timeout
# Whisper checkpoint name; larger checkpoints are more accurate but slower.
WHISPER_MODEL_SIZE = "base"
# edge-tts voice identifier (Mandarin Chinese neural voice).
EDGE_TTS_VOICE = "zh-CN-XiaoxiaoNeural"
FFMPEG_BIN = which("ffmpeg") or "ffmpeg"
MAX_AUDIO_BUFFER_BYTES: int = 10 * 1024 * 1024  # 10 MB cap per utterance


# ── Pydantic models ─────────────────────────────────────────────────


class HardwareAction(str, Enum):
    """Actions reported by XiaoZhi ESP32 hardware."""

    KEY_DOWN = "key_down"  # push-to-talk pressed: start buffering audio
    KEY_UP = "key_up"  # push-to-talk released: run the STT→LLM→TTS pipeline
    WAKE_WORD = "wake_word"  # wake-word detected: handled like key_down


class HardwareEvent(BaseModel):
    """Incoming JSON frame from hardware."""

    # Frame discriminator sent by the device; its value is not inspected here.
    type: str
    # Which hardware action occurred; drives the press-to-talk state machine.
    action: HardwareAction


class StateNotification(BaseModel):
    """Outgoing state frame pushed to hardware (e.g. breathing LED trigger)."""

    type: str = "state"  # fixed frame discriminator
    # Current proxy phase, e.g. "listening" or "thinking" (see frame router).
    status: str


class TextReply(BaseModel):
    """Outgoing text reply carrying the LLM response."""

    type: str = "text_reply"  # fixed frame discriminator
    text: str  # agent reply text returned by OpenClaw


class SessionState(BaseModel):
    """Connection-scoped press-to-talk state: current phase plus buffered audio."""

    model_config = {"arbitrary_types_allowed": True}

    status: Literal["idle", "listening", "processing"] = "idle"
    audio_chunks: list[bytes] = []

    def buffer_size(self) -> int:
        """Total number of audio bytes currently buffered."""
        return sum(map(len, self.audio_chunks))

    def drain_audio(self) -> bytes:
        """Concatenate every buffered chunk, empty the buffer, and return the audio."""
        joined = b"".join(self.audio_chunks)
        del self.audio_chunks[:]
        return joined

    def reset(self) -> None:
        """Discard any buffered audio and return to the idle phase."""
        self.audio_chunks.clear()
        self.status = "idle"


# ── Whisper model (lazy singleton) ─────────────────────────────────

_whisper_model: whisper.Whisper | None = None


def _get_whisper_model() -> whisper.Whisper:
    """Return the module-wide Whisper model, loading it lazily on first use."""
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model
    logger.info("Loading Whisper '%s' model…", WHISPER_MODEL_SIZE)
    _whisper_model = whisper.load_model(WHISPER_MODEL_SIZE)
    return _whisper_model


# ── STT / TTS pipeline ────────────────────────────────────────────


async def process_stt(audio_bytes: bytes) -> str:
    """Convert Opus audio to text via Whisper.

    Args:
        audio_bytes: Raw Opus-encoded audio from ESP32.

    Returns:
        Transcribed text string.
    """
    logger.info("STT: received %d bytes of Opus audio", len(audio_bytes))
    model = _get_whisper_model()
    # delete=False + explicit unlink: Whisper re-opens the file by path
    # (via ffmpeg), and on Windows an open NamedTemporaryFile cannot be
    # reopened by name, so the previous delete=True context-manager form
    # fails there.
    tmp = tempfile.NamedTemporaryFile(suffix=".opus", delete=False)
    try:
        tmp.write(audio_bytes)
        tmp.close()
        # transcribe() is blocking and CPU-heavy; run it off the event loop.
        result = await asyncio.to_thread(
            model.transcribe, tmp.name, language="zh"
        )
    finally:
        os.unlink(tmp.name)
    text = result["text"].strip()
    logger.info("STT: transcribed → %s", text)
    return text


async def call_openclaw(text: str, session_id: str) -> str:
    """Send text to OpenClaw agent via CLI and return its reply.

    Args:
        text: User utterance transcribed by STT.
        session_id: Per-connection session ID for multi-turn conversation.

    Returns:
        Agent reply string.

    Raises:
        RuntimeError: If the CLI call fails or returns unexpected output.
    """
    logger.info("LLM: sending to OpenClaw → %s", text)

    proc = await asyncio.create_subprocess_exec(
        OPENCLAW_CLI, "agent",
        "--agent", OPENCLAW_AGENT_ID,
        "--session-id", session_id,
        "-m", text,
        "--json",
        "--timeout", str(OPENCLAW_TIMEOUT),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()

    if proc.returncode != 0:
        err_msg = stderr.decode(errors="replace").strip()
        logger.error("OpenClaw CLI failed (rc=%d): %s", proc.returncode, err_msg)
        raise RuntimeError(f"OpenClaw CLI error: {err_msg}")

    # Normalise parse failures to RuntimeError so callers only have one
    # exception type to handle (the docstring promises RuntimeError, but a
    # bare json.loads would raise JSONDecodeError/UnicodeDecodeError).
    try:
        data = json.loads(stdout.decode())
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        logger.error("OpenClaw returned unparsable output: %.200s", stdout)
        raise RuntimeError(f"OpenClaw returned invalid JSON: {exc}") from exc

    payloads = data.get("result", {}).get("payloads", [])
    if not payloads:
        raise RuntimeError("OpenClaw returned empty payloads")

    reply = payloads[0].get("text", "")
    logger.info("LLM: reply ← %s", reply)
    return reply


async def process_tts(text: str) -> bytes:
    """Convert text to Opus audio via edge-tts and ffmpeg.

    Args:
        text: Text to synthesize.

    Returns:
        Opus-encoded audio bytes.

    Raises:
        RuntimeError: If edge-tts produces no audio or ffmpeg conversion fails.
    """
    logger.info("TTS: synthesising → %s", text)

    # 1. edge-tts → MP3
    communicate = edge_tts.Communicate(text=text, voice=EDGE_TTS_VOICE)
    mp3_chunks: list[bytes] = []
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            mp3_chunks.append(chunk["data"])
    mp3_bytes = b"".join(mp3_chunks)

    # Fail early with a clear message instead of feeding ffmpeg an empty
    # stdin (e.g. empty reply text or a service hiccup), which would only
    # surface as an opaque ffmpeg error.
    if not mp3_bytes:
        raise RuntimeError("edge-tts produced no audio")

    # 2. MP3 → Opus via ffmpeg (16 kHz mono, 24 kbps)
    proc = await asyncio.create_subprocess_exec(
        FFMPEG_BIN, "-i", "pipe:0",
        "-c:a", "libopus", "-b:a", "24k",
        "-ar", "16000", "-ac", "1",
        "-f", "opus", "pipe:1",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    opus_bytes, stderr = await proc.communicate(input=mp3_bytes)

    if proc.returncode != 0:
        err_msg = stderr.decode(errors="replace").strip()
        logger.error("ffmpeg failed (rc=%d): %s", proc.returncode, err_msg)
        raise RuntimeError(f"ffmpeg TTS conversion error: {err_msg}")

    logger.info("TTS: generated %d bytes of Opus audio", len(opus_bytes))
    return opus_bytes


# ── Frame router ────────────────────────────────────────────────────


async def handle_text_frame(
    ws: WebSocket, raw: str, state: SessionState, session_id: str,
) -> None:
    """Decode an incoming JSON control frame and drive the state machine.

    Args:
        ws: Active WebSocket connection.
        raw: Raw text payload received from client.
        state: Per-connection session state.
        session_id: Per-connection session ID.
    """
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("Malformed JSON (ignored): %.120s", raw)
        return

    try:
        event = HardwareEvent(**decoded)
    except ValidationError as exc:
        logger.warning("Invalid hardware event: %s", exc.errors())
        return

    logger.info("Hardware event: %s", event.action.value)

    if event.action is HardwareAction.KEY_UP:
        # Release only means something while we are actually capturing.
        if state.status != "listening":
            logger.warning("key_up in '%s' state — ignored", state.status)
            return
        state.status = "processing"
        await handle_pipeline(ws, state, session_id)
        return

    # Remaining actions (key_down / wake_word) both (re)start a capture window.
    if state.status == "listening":
        logger.info("Re-key_down while listening — resetting buffer")
        state.audio_chunks.clear()
    state.status = "listening"
    await ws.send_text(StateNotification(status="listening").model_dump_json())


async def handle_binary_frame(ws: WebSocket, data: bytes, state: SessionState) -> None:
    """Buffer an audio binary frame if in listening state.

    Args:
        ws: Active WebSocket connection (unused; kept for router symmetry).
        data: Raw Opus audio bytes.
        state: Per-connection session state.
    """
    if state.status != "listening":
        logger.warning("Binary frame in '%s' state — dropped", state.status)
        return

    buffered = state.buffer_size()
    if buffered + len(data) > MAX_AUDIO_BUFFER_BYTES:
        logger.warning("Audio buffer full (%d bytes) — dropping frame", buffered)
        return

    state.audio_chunks.append(data)
    # Per-chunk logging sits on the audio hot path (one line per Opus frame);
    # keep it at DEBUG so a normal utterance does not flood the INFO log.
    logger.debug("Buffered %d bytes (total: %d)", len(data), buffered + len(data))


async def handle_pipeline(
    ws: WebSocket, state: SessionState, session_id: str,
) -> None:
    """Drain buffered audio and run STT → LLM → TTS pipeline.

    A failure in any stage is logged and reported to the hardware instead of
    propagating: previously an exception escaped into the receive loop, tearing
    down the connection while the device was stuck showing "thinking".

    Args:
        ws: Active WebSocket connection.
        state: Per-connection session state.
        session_id: Per-connection session ID.
    """
    audio = state.drain_audio()
    if not audio:
        logger.warning("key_up with empty buffer — skipping pipeline")
        state.reset()
        return

    try:
        # Stage 1 — STT
        text = await process_stt(audio)

        # Notify hardware: entering thinking phase
        await ws.send_text(StateNotification(status="thinking").model_dump_json())

        # Stage 2 — LLM (with periodic thinking heartbeats)
        async def heartbeat() -> None:
            """Send thinking heartbeat every 3 seconds while LLM is working."""
            while True:
                await asyncio.sleep(3)
                try:
                    await ws.send_text(
                        StateNotification(status="thinking").model_dump_json()
                    )
                except Exception:
                    break

        heartbeat_task = asyncio.create_task(heartbeat())
        try:
            reply = await call_openclaw(text, session_id)
        finally:
            heartbeat_task.cancel()
            try:
                # Await the cancelled task so a late heartbeat send cannot
                # interleave with the reply frames below.
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        await ws.send_text(TextReply(text=reply).model_dump_json())

        # Stage 3 — TTS
        tts_audio = await process_tts(reply)
        logger.info("Pushing %d bytes of audio to client", len(tts_audio))
        await ws.send_bytes(tts_audio)
    except Exception:
        # Keep the WebSocket alive across a bad utterance; log the full trace
        # and nudge the device back to idle (best effort — the socket itself
        # may already be gone, and the firmware's handling of an "idle" state
        # frame should be confirmed against the device code).
        logger.exception("Pipeline failed (session=%s)", session_id)
        try:
            await ws.send_text(StateNotification(status="idle").model_dump_json())
        except Exception:
            pass
    finally:
        state.reset()


# ── FastAPI application ─────────────────────────────────────────────

# Single ASGI application instance served by uvicorn (see entry point below).
app = FastAPI(title="openclaw-proxy", version="0.1.0")


@app.websocket("/chat")
async def websocket_chat(ws: WebSocket) -> None:
    """Main WebSocket endpoint for XiaoZhi ESP32 hardware."""
    await ws.accept()
    peer = f"{ws.client.host}:{ws.client.port}" if ws.client else "unknown"
    session_id = f"xiaozhi-{uuid.uuid4().hex[:8]}"
    state = SessionState()
    logger.info("Client connected: %s (session=%s)", peer, session_id)

    try:
        while True:
            frame = await ws.receive()
            if frame["type"] == "websocket.disconnect":
                break
            # Each ASGI message carries either a text or a bytes payload.
            if (text := frame.get("text")) is not None:
                await handle_text_frame(ws, text, state, session_id)
            elif (blob := frame.get("bytes")) is not None:
                await handle_binary_frame(ws, blob, state)
    except WebSocketDisconnect:
        pass
    finally:
        logger.info("Client disconnected: %s (session=%s)", peer, session_id)


@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe."""
    return dict(status="ok")


# ── Entry point ─────────────────────────────────────────────────────

if __name__ == "__main__":
    import uvicorn

    # reload=True is a development convenience; run under a process manager
    # without reload in production.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
