"""Tests for openclaw-proxy WebSocket middleware."""

import json
from collections.abc import AsyncIterator
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from starlette.testclient import TestClient

from main import (
    SessionState,
    app,
    process_stt,
    call_openclaw,
    process_tts,
    MAX_AUDIO_BUFFER_BYTES,
)


# ── Helpers ──────────────────────────────────────────────────────────

# Canned stdout of a successful OpenClaw CLI run: one text payload, no media.
MOCK_OPENCLAW_OUTPUT = json.dumps(
    {
        "status": "ok",
        "result": {"payloads": [{"text": "你好！我在这儿。", "mediaUrl": None}]},
    }
).encode()


def _make_mock_process(
    stdout: bytes = MOCK_OPENCLAW_OUTPUT,
    returncode: int = 0,
    stderr: bytes = b"",
) -> AsyncMock:
    """Build a mock asyncio subprocess with preset stdout/stderr and exit code.

    Args:
        stdout: Bytes returned as stdout by the awaited ``communicate()``;
            defaults to a successful OpenClaw CLI response.
        returncode: Value exposed as the process's ``returncode``.
        stderr: Bytes returned as stderr by the awaited ``communicate()``.

    Returns:
        An ``AsyncMock`` whose ``communicate()`` resolves to
        ``(stdout, stderr)``.
    """
    proc = AsyncMock()
    proc.communicate.return_value = (stdout, stderr)
    proc.returncode = returncode
    return proc


def _hardware_event_msg(action: str) -> str:
    """Serialize a ``hardware_event`` with the given *action* as JSON text."""
    event = {"type": "hardware_event", "action": action}
    return json.dumps(event)


def _key_down_msg() -> str:
    """JSON text frame for a push-to-talk ``key_down`` event."""
    return _hardware_event_msg("key_down")


def _key_up_msg() -> str:
    """JSON text frame for a push-to-talk ``key_up`` event."""
    return _hardware_event_msg("key_up")


def _wake_word_msg() -> str:
    """JSON text frame for a ``wake_word`` event."""
    return _hardware_event_msg("wake_word")


def _pipeline_patches(
    *,
    stt_text: str = "测试语音输入",
    reply_text: str = "你好！我在这儿。",
    tts_audio: bytes = b"\xfc\xff\xfe",
) -> tuple:
    """Build patchers that stub out the STT → LLM → TTS pipeline in ``main``.

    This returns a *tuple of three patchers* (STT, LLM, TTS), not a single
    context manager. Enter them together, e.g.
    ``with p_stt as mock_stt, p_llm, p_tts: ...``. Each one replaces its
    target with an ``AsyncMock`` that resolves to the given canned value.

    Args:
        stt_text: Value returned by the patched ``main.process_stt``.
        reply_text: Value returned by the patched ``main.call_openclaw``.
        tts_audio: Value returned by the patched ``main.process_tts``.

    Returns:
        ``(stt_patcher, llm_patcher, tts_patcher)``.
    """
    return (
        patch("main.process_stt", new_callable=AsyncMock, return_value=stt_text),
        patch(
            "main.call_openclaw",
            new_callable=AsyncMock,
            return_value=reply_text,
        ),
        patch("main.process_tts", new_callable=AsyncMock, return_value=tts_audio),
    )


# ── Unit tests: pipeline functions ───────────────────────────────────


@pytest.mark.asyncio
async def test_process_stt_calls_whisper() -> None:
    """process_stt should pass an ``.opus`` temp file to Whisper with
    ``language="zh"`` and return the transcript stripped of whitespace."""
    whisper = MagicMock()
    whisper.transcribe.return_value = {"text": " 你好世界 "}

    with patch("main._get_whisper_model", return_value=whisper):
        transcript = await process_stt(b"\x00" * 80)

    assert transcript == "你好世界"
    whisper.transcribe.assert_called_once()
    call = whisper.transcribe.call_args
    assert call.args[0].endswith(".opus")
    assert call.kwargs == {"language": "zh"}


@pytest.mark.asyncio
async def test_call_openclaw_parses_cli_output() -> None:
    """The reply text is extracted from the CLI's JSON ``payloads`` list."""
    with patch(
        "main.asyncio.create_subprocess_exec",
        return_value=_make_mock_process(),
    ):
        reply = await call_openclaw("你好", "test-session")
    assert reply == "你好！我在这儿。"


@pytest.mark.asyncio
async def test_call_openclaw_raises_on_cli_failure() -> None:
    """A non-zero CLI exit status must surface as RuntimeError."""
    failing = _make_mock_process(stdout=b"", returncode=1)
    failing.communicate.return_value = (b"", b"some error")
    with (
        patch("main.asyncio.create_subprocess_exec", return_value=failing),
        pytest.raises(RuntimeError, match="OpenClaw CLI error"),
    ):
        await call_openclaw("你好", "test-session")


@pytest.mark.asyncio
async def test_call_openclaw_raises_on_empty_payloads() -> None:
    """A clean exit whose ``payloads`` list is empty is still a RuntimeError."""
    no_payloads = {"status": "ok", "result": {"payloads": []}}
    proc = _make_mock_process(stdout=json.dumps(no_payloads).encode())
    with (
        patch("main.asyncio.create_subprocess_exec", return_value=proc),
        pytest.raises(RuntimeError, match="empty payloads"),
    ):
        await call_openclaw("你好", "test-session")


@pytest.mark.asyncio
async def test_process_tts_returns_opus_bytes() -> None:
    """TTS should collect MP3 from edge-tts, convert to Opus via ffmpeg.

    Only ``"audio"`` chunks from the stream may be concatenated and piped to
    the converter; the non-audio chunk must be skipped.
    """

    # An async generator, so its return type is AsyncIterator, not None.
    async def mock_stream() -> AsyncIterator[dict]:
        yield {"type": "audio", "data": b"mp3-part1"}
        yield {"type": "audio", "data": b"mp3-part2"}
        yield {"type": "WordBoundary", "data": None}  # non-audio chunk

    mock_comm = MagicMock()
    mock_comm.stream = mock_stream

    mock_proc = _make_mock_process(stdout=b"fake-opus")

    with (
        patch("main.edge_tts.Communicate", return_value=mock_comm),
        patch("main.asyncio.create_subprocess_exec", return_value=mock_proc),
    ):
        result = await process_tts("测试文本")

    assert result == b"fake-opus"
    # awaited (not merely called): the converter output was actually consumed.
    mock_proc.communicate.assert_awaited_once_with(input=b"mp3-part1mp3-part2")


@pytest.mark.asyncio
async def test_process_tts_raises_on_ffmpeg_failure() -> None:
    """TTS should raise RuntimeError when ffmpeg exits non-zero."""

    # An async generator, so its return type is AsyncIterator, not None.
    async def mock_stream() -> AsyncIterator[dict]:
        yield {"type": "audio", "data": b"mp3-data"}

    mock_comm = MagicMock()
    mock_comm.stream = mock_stream

    # Failed conversion: no stdout, an error message on stderr, exit code 1.
    mock_proc = _make_mock_process(stdout=b"", returncode=1)
    mock_proc.communicate.return_value = (b"", b"ffmpeg error msg")

    with (
        patch("main.edge_tts.Communicate", return_value=mock_comm),
        patch("main.asyncio.create_subprocess_exec", return_value=mock_proc),
    ):
        with pytest.raises(RuntimeError, match="ffmpeg TTS conversion error"):
            await process_tts("测试文本")


# ── Unit tests: SessionState ─────────────────────────────────────────


def test_session_state_defaults() -> None:
    """A freshly built SessionState is idle and holds no buffered audio."""
    fresh = SessionState()

    assert fresh.status == "idle"
    assert fresh.audio_chunks == []
    assert fresh.buffer_size() == 0


# ── Integration tests: HTTP & WebSocket ──────────────────────────────


def test_health_endpoint() -> None:
    """The /health probe answers 200 with ``{"status": "ok"}``."""
    response = TestClient(app).get("/health")

    assert response.status_code == 200
    assert response.json() == {"status": "ok"}


def test_websocket_connect_and_disconnect() -> None:
    """Opening /chat and immediately leaving the session must not raise."""
    client = TestClient(app)
    session = client.websocket_connect("/chat")
    with session:
        pass


def test_websocket_text_frame_wake_word() -> None:
    """A wake_word event is answered with a 'listening' state frame."""
    client = TestClient(app)
    with client.websocket_connect("/chat") as ws:
        ws.send_text(_wake_word_msg())
        reply = ws.receive_json()
        assert reply["type"] == "state"
        assert reply["status"] == "listening"


def test_websocket_text_frame_malformed_json() -> None:
    """Garbage text must be tolerated; the socket keeps serving events."""
    client = TestClient(app)
    with client.websocket_connect("/chat") as ws:
        for frame in ("not-valid-json{{{", _wake_word_msg()):
            ws.send_text(frame)
        assert ws.receive_json()["status"] == "listening"


def test_key_down_starts_listening() -> None:
    """key_down switches the session to listening and announces it."""
    client = TestClient(app)
    with client.websocket_connect("/chat") as ws:
        ws.send_text(_key_down_msg())
        assert ws.receive_json() == {"type": "state", "status": "listening"}


def test_key_up_triggers_pipeline() -> None:
    """Push-to-talk round trip: key_down, one audio frame, key_up.

    The server should answer with a 'thinking' state, then the LLM text
    reply, then the synthesized audio, in that order.
    """
    stt_patch, llm_patch, tts_patch = _pipeline_patches()
    client = TestClient(app)
    with (
        stt_patch,
        llm_patch,
        tts_patch,
        client.websocket_connect("/chat") as ws,
    ):
        ws.send_text(_key_down_msg())
        ws.receive_text()  # drain the "listening" acknowledgement

        ws.send_bytes(b"\x00" * 80)
        ws.send_text(_key_up_msg())

        assert ws.receive_json() == {"type": "state", "status": "thinking"}

        reply = ws.receive_json()
        assert reply["type"] == "text_reply"
        assert reply["text"] == "你好！我在这儿。"

        assert ws.receive_bytes() == b"\xfc\xff\xfe"


def test_key_up_empty_buffer_skips_pipeline() -> None:
    """key_down → key_up with no audio should skip pipeline, connection alive."""
    p_stt, p_llm, p_tts = _pipeline_patches()
    client = TestClient(app)
    with p_stt as mock_stt, p_llm as mock_llm, p_tts as mock_tts:
        with client.websocket_connect("/chat") as ws:
            ws.send_text(_key_down_msg())
            ws.receive_text()  # consume "listening"

            ws.send_text(_key_up_msg())

            # Connection should still be alive — send another event. The
            # reply is also a synchronization point: assuming the server
            # handles frames in order, the key_up above is done by now.
            ws.send_text(_wake_word_msg())
            resp = json.loads(ws.receive_text())
            assert resp["status"] == "listening"

            # Pipeline should NOT have been called. This is checked only after
            # the round trip. The app runs in a separate thread, so asserting
            # right after send_text would race it and pass vacuously.
            mock_stt.assert_not_called()
            mock_llm.assert_not_called()
            mock_tts.assert_not_called()


def test_binary_frame_while_idle_is_dropped() -> None:
    """Binary frame in idle state should be dropped, connection alive."""
    p_stt, p_llm, p_tts = _pipeline_patches()
    client = TestClient(app)
    with p_stt as mock_stt, p_llm as mock_llm, p_tts as mock_tts:
        with client.websocket_connect("/chat") as ws:
            ws.send_bytes(b"\x00" * 80)

            # Connection should still be alive. The reply is also a
            # synchronization point: assuming the server handles frames in
            # order, the binary frame above is done by now.
            ws.send_text(_wake_word_msg())
            resp = json.loads(ws.receive_text())
            assert resp["status"] == "listening"

            # Pipeline should NOT have been called. This is checked only after
            # the round trip. The app runs in a separate thread, so asserting
            # right after send_bytes would race it and pass vacuously.
            mock_stt.assert_not_called()
            mock_llm.assert_not_called()
            mock_tts.assert_not_called()


def test_binary_frames_are_concatenated() -> None:
    """Audio frames received while listening reach STT as one joined blob."""
    frames = [b"\x01" * 10, b"\x02" * 20, b"\x03" * 30]
    stt_patch, llm_patch, tts_patch = _pipeline_patches()
    client = TestClient(app)
    with stt_patch as mock_stt, llm_patch, tts_patch:
        with client.websocket_connect("/chat") as ws:
            ws.send_text(_key_down_msg())
            ws.receive_text()  # "listening"

            for frame in frames:
                ws.send_bytes(frame)
            ws.send_text(_key_up_msg())

            # Drain the pipeline output: thinking, text_reply, audio.
            ws.receive_text()
            ws.receive_text()
            ws.receive_bytes()

            mock_stt.assert_called_once_with(b"".join(frames))


def test_repeated_key_down_resets_buffer() -> None:
    """A second key_down discards earlier audio; only the later chunk hits STT."""
    discarded, kept = b"\xAA" * 50, b"\xBB" * 30
    stt_patch, llm_patch, tts_patch = _pipeline_patches()
    client = TestClient(app)
    with stt_patch as mock_stt, llm_patch, tts_patch:
        with client.websocket_connect("/chat") as ws:
            for chunk in (discarded, kept):
                # Each key_down (re)starts recording and is acknowledged.
                ws.send_text(_key_down_msg())
                ws.receive_text()
                ws.send_bytes(chunk)

            ws.send_text(_key_up_msg())

            for _ in range(2):
                ws.receive_text()  # thinking, then text_reply
            ws.receive_bytes()  # synthesized audio

            mock_stt.assert_called_once_with(kept)


def test_key_up_without_key_down_is_ignored() -> None:
    """key_up in idle state should be ignored, connection alive."""
    p_stt, p_llm, p_tts = _pipeline_patches()
    client = TestClient(app)
    with p_stt as mock_stt, p_llm as mock_llm, p_tts as mock_tts:
        with client.websocket_connect("/chat") as ws:
            ws.send_text(_key_up_msg())

            # Connection should still be alive. The reply is also a
            # synchronization point: assuming the server handles frames in
            # order, the stray key_up above is done by now.
            ws.send_text(_wake_word_msg())
            resp = json.loads(ws.receive_text())
            assert resp["status"] == "listening"

            # Pipeline should NOT have been called. This is checked only after
            # the round trip. The app runs in a separate thread, so asserting
            # right after send_text would race it and pass vacuously.
            mock_stt.assert_not_called()
            mock_llm.assert_not_called()
            mock_tts.assert_not_called()


def test_websocket_binary_frame_full_pipeline() -> None:
    """key_down → binary → key_up should trigger STT → LLM → TTS pipeline.

    Expected response sequence:
      1. {"type": "state", "status": "thinking"}
      2. {"type": "text_reply", "text": "..."}
      3. binary bytes (TTS audio)

    Beyond the response sequence, this also checks the wiring between stages:
      - STT receives the recorded audio.
      - The LLM receives the STT transcript.
      - TTS is awaited.
    """
    audio_in = b"\x00" * 80
    p_stt, p_llm, p_tts = _pipeline_patches()
    client = TestClient(app)
    with p_stt as mock_stt, p_llm as mock_llm, p_tts as mock_tts:
        with client.websocket_connect("/chat") as ws:
            ws.send_text(_key_down_msg())
            ws.receive_text()  # consume "listening"

            ws.send_bytes(audio_in)

            ws.send_text(_key_up_msg())

            # 1) thinking state
            msg1 = json.loads(ws.receive_text())
            assert msg1 == {"type": "state", "status": "thinking"}

            # 2) text reply
            msg2 = json.loads(ws.receive_text())
            assert msg2["type"] == "text_reply"
            assert msg2["text"] == "你好！我在这儿。"

            # 3) TTS audio bytes
            audio = ws.receive_bytes()
            assert audio == b"\xfc\xff\xfe"

        # Stage wiring: each stage must consume the previous stage's output.
        mock_stt.assert_called_once_with(audio_in)
        mock_llm.assert_awaited_once()
        llm_call = mock_llm.call_args
        # Accept the transcript positionally or by keyword, possibly embedded
        # in a larger prompt string.
        llm_inputs = (*llm_call.args, *llm_call.kwargs.values())
        assert any("测试语音输入" in str(arg) for arg in llm_inputs)
        mock_tts.assert_awaited()
