"""Tests for automatic agent selection."""

from __future__ import annotations

import json

from src.agent_selector import auto_select_agents
from src.models import AgentConfig


# Fixture pool of candidate agents handed to the selector. Each entry is keyed
# by the same string passed as the AgentConfig ``name``, so the key and the
# config name cannot drift apart.
SAMPLE_AGENTS: dict[str, AgentConfig] = {
    agent_name: AgentConfig(name=agent_name, role=agent_role, system_prompt=agent_prompt)
    for agent_name, agent_role, agent_prompt in (
        ("architect", "Software Architect", "You are an architect."),
        ("devops", "DevOps Engineer", "You are DevOps."),
        ("business_analyst", "Business Analyst", "You analyze business."),
    )
}


class MockLLMClient:
    """Fake LLM client that answers every ``chat`` call with one canned response.

    Attributes:
        response: The text returned from every ``chat`` call.
        call_count: How many times ``chat`` has been invoked.
        calls: The ``(system, messages)`` arguments of each call, in call
            order, so tests can inspect what the code under test sent.
    """

    def __init__(self, response: str) -> None:
        self.response = response
        self.call_count = 0
        self.calls: list[tuple[str, list[dict[str, str]]]] = []

    def chat(self, system: str, messages: list[dict[str, str]]) -> str:
        """Record the call's arguments and return the canned response.

        Args:
            system: System prompt supplied by the caller.
            messages: Conversation messages supplied by the caller.

        Returns:
            ``self.response``, unchanged.
        """
        self.call_count += 1
        # Store a shallow copy of the list: a caller that keeps appending to
        # the same list (e.g. while retrying) must not rewrite the history of
        # earlier calls.
        self.calls.append((system, list(messages)))
        return self.response


class TestAutoSelectAgents:
    """Tests for auto_select_agents function."""

    def test_returns_valid_agent_names(self) -> None:
        """auto_select_agents returns agent names from the valid set."""
        response = json.dumps({
            "selected_agents": ["architect", "devops"],
            "reasoning": "Architecture and deployment are key.",
        })
        mock_client = MockLLMClient(response)
        result = auto_select_agents(
            topic="Deploy new service",
            context="",
            all_agents=SAMPLE_AGENTS,
            client=mock_client,
        )
        assert isinstance(result, list)
        assert all(name in SAMPLE_AGENTS for name in result)
        assert "architect" in result
        assert "devops" in result
        # A well-formed first response should be accepted without a retry.
        assert mock_client.call_count == 1

    def test_respects_max_agents(self) -> None:
        """auto_select_agents truncates an over-long valid selection to max_agents."""
        response = json.dumps({
            "selected_agents": ["architect", "devops", "business_analyst"],
            "reasoning": "All relevant.",
        })
        mock_client = MockLLMClient(response)
        result = auto_select_agents(
            topic="Test",
            context="",
            all_agents=SAMPLE_AGENTS,
            client=mock_client,
            max_agents=2,
        )
        # The LLM proposed three valid agents, so only the cap limits the
        # size: exactly two must come back. A bare ``<= 2`` check would also
        # pass for an empty result and hide a broken truncation.
        assert len(result) == 2
        assert len(set(result)) == len(result)
        assert set(result) <= set(SAMPLE_AGENTS)

    def test_retries_on_invalid_json(self) -> None:
        """auto_select_agents retries when LLM returns invalid JSON."""
        call_count = 0

        class RetryMockClient:
            """Returns non-JSON on the first call and valid JSON afterwards."""

            def chat(self, system: str, messages: list[dict[str, str]]) -> str:
                nonlocal call_count
                call_count += 1
                if call_count < 2:
                    return "not json"
                return json.dumps({
                    "selected_agents": ["architect"],
                    "reasoning": "OK.",
                })

        result = auto_select_agents(
            topic="Test",
            context="",
            all_agents=SAMPLE_AGENTS,
            client=RetryMockClient(),
        )
        assert result == ["architect"]
        # Exactly one retry: the bad first reply, then the good second one.
        assert call_count == 2

    def test_filters_invalid_agent_names(self) -> None:
        """auto_select_agents filters out agent names not in available set."""
        response = json.dumps({
            "selected_agents": ["architect", "nonexistent_agent"],
            "reasoning": "Mixed.",
        })
        mock_client = MockLLMClient(response)
        result = auto_select_agents(
            topic="Test",
            context="",
            all_agents=SAMPLE_AGENTS,
            client=mock_client,
        )
        assert "architect" in result
        assert "nonexistent_agent" not in result
        # Every surviving name must be a real agent, not just the one known-bad name absent.
        assert all(name in SAMPLE_AGENTS for name in result)
