#!/usr/bin/env python3
"""
Gemini API Diagnostic Tool

Tests Gemini API with minimal payload to identify issues.
"""

import os
import sys
import json
# Make project-local packages (health, slack_bot) importable when the script
# is run from the repository root — assumes CWD is the repo root; TODO confirm.
sys.path.append(os.getcwd())

from health.utils.env_loader import load_env_with_extras
from health.utils.logging_config import setup_logger
from slack_bot.llm.gemini import GeminiLLM, get_system_instruction

# Module-level logger; not referenced below. NOTE(review): setup_logger may
# have side effects (handler configuration) — confirm before removing.
logger = setup_logger(__name__)

def test_minimal():
    """Send a bare greeting with no history, tools, or images.

    Returns True when a non-empty response comes back, False otherwise.
    """
    print("\n=== Test 1: Minimal Context ===")

    llm = GeminiLLM()

    print(f"Model: {llm.get_model_name()}")
    print(f"Using proxy: {llm.use_proxy}")
    if llm.use_proxy:
        print(f"Base URL: {os.environ.get('GEMINI_BASE_URL')}")

    response, tool_calls = llm.generate_response(
        message="Hello, can you hear me?",
        context=[],
        tools=None,
        images=None
    )

    print(f"Response: {response}")
    print(f"Tool calls: {tool_calls}")
    print(f"Response length: {len(response) if response else 0}")

    # A missing or whitespace-only reply counts as a failure.
    if response and response.strip():
        print("✅ Got response")
        return True
    print("❌ EMPTY RESPONSE!")
    return False


def test_with_tools():
    """Send a trivial question alongside a single minimal tool schema.

    Returns True when a non-empty response comes back, False otherwise.
    """
    print("\n=== Test 2: With Simple Tool ===")

    llm = GeminiLLM()

    # One small function declaration — keeps the tool payload minimal.
    calc_schema = {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Calculate a math expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "Math expression"
                    }
                },
                "required": ["expression"]
            }
        }
    }

    response, tool_calls = llm.generate_response(
        message="What's 2+2?",
        context=[],
        tools=[calc_schema],
        images=None
    )

    print(f"Response: {response}")
    print(f"Tool calls: {tool_calls}")

    # A missing or whitespace-only reply counts as a failure.
    if response and response.strip():
        print("✅ Got response")
        return True
    print("❌ EMPTY RESPONSE!")
    return False


def test_with_context():
    """Send a request with a fabricated 10-message conversation history.

    Returns True when a non-empty response comes back, False otherwise.
    """
    print("\n=== Test 3: With Context (10 messages) ===")

    llm = GeminiLLM()

    # Five alternating user/assistant turns → 10 context entries,
    # same content as the original loop ("Message 0", "Response 0", ...).
    context = [
        {"role": role, "content": f"{label} {turn * 2}"}
        for turn in range(5)
        for role, label in (("user", "Message"), ("assistant", "Response"))
    ]

    response, tool_calls = llm.generate_response(
        message="Summarize our conversation",
        context=context,
        tools=None,
        images=None
    )

    print(f"Response: {response}")
    print(f"Response length: {len(response) if response else 0}")

    # A missing or whitespace-only reply counts as a failure.
    if response and response.strip():
        print("✅ Got response")
        return True
    print("❌ EMPTY RESPONSE!")
    return False


def test_system_prompt_size():
    """Report the system prompt's size in characters and tokens.

    Informational only: prints size stats plus the head and tail of the
    prompt, returns nothing, and is not scored in the summary.
    """
    print("\n=== Test 4: System Prompt Analysis ===")

    system_prompt = get_system_instruction()

    print(f"System prompt length: {len(system_prompt)} chars")
    # ~3 chars/token is a crude heuristic; tiktoken below is more precise.
    print(f"Estimated tokens: ~{len(system_prompt) // 3}")

    # Prefer an exact count when tiktoken is installed. Fixed: the original
    # bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    try:
        import tiktoken
        encoder = tiktoken.get_encoding("cl100k_base")
        actual_tokens = len(encoder.encode(system_prompt))
        print(f"Actual tokens (tiktoken): {actual_tokens}")
    except Exception:
        print("(tiktoken not available for precise count)")

    print("\n--- First 500 chars of system prompt ---")
    print(system_prompt[:500])
    print("\n--- Last 500 chars of system prompt ---")
    print(system_prompt[-500:])


def test_health_tools():
    """Send a real query with the full health-tool schema attached.

    Returns True when a non-empty response comes back, False otherwise.
    """
    print("\n=== Test 5: With Health Tools (Full Schema) ===")

    from slack_bot.tools.registry import TOOLS_SCHEMA

    llm = GeminiLLM()

    print(f"Number of tools: {len(TOOLS_SCHEMA)}")

    # Rough payload-size check of the serialized schema (~3 chars/token).
    serialized = json.dumps(TOOLS_SCHEMA, ensure_ascii=False)
    print(f"Tools schema length: {len(serialized)} chars")
    print(f"Estimated tokens: ~{len(serialized) // 3}")

    response, tool_calls = llm.generate_response(
        message="今天睡得怎么样？",
        context=[],
        tools=TOOLS_SCHEMA,
        images=None
    )

    print(f"Response: {response}")
    print(f"Tool calls: {tool_calls}")

    # A missing or whitespace-only reply counts as a failure.
    if response and response.strip():
        print("✅ Got response")
        return True
    print("❌ EMPTY RESPONSE!")
    return False


def _run_scored(test_fn):
    """Run one scored diagnostic; any exception counts as a failure."""
    try:
        return test_fn()
    except Exception as e:
        print(f"❌ Test failed with error: {e}")
        return False


def main():
    """Run all diagnostic tests and print a pass/fail summary."""
    load_env_with_extras()

    print("=" * 60)
    print("Gemini API Diagnostic Tool")
    print("=" * 60)

    print("\n📋 Configuration:")
    print(f"  GEMINI_MODEL: {os.environ.get('GEMINI_MODEL')}")
    print(f"  GEMINI_BASE_URL: {os.environ.get('GEMINI_BASE_URL', 'None (Direct Google API)')}")
    # NOTE(review): printing a 20-char key prefix may be sensitive —
    # avoid sharing this output verbatim.
    print(f"  GEMINI_API_KEY: {os.environ.get('GEMINI_API_KEY', 'Not set')[:20]}...")

    results = []

    # Scored tests run in order; a crash in one does not stop the rest.
    for name, test_fn in [
        ("Minimal", test_minimal),
        ("With Tools", test_with_tools),
        ("With Context", test_with_context),
    ]:
        results.append((name, _run_scored(test_fn)))

    # Informational only — not included in the pass/fail summary.
    try:
        test_system_prompt_size()
    except Exception as e:
        print(f"❌ Test failed with error: {e}")

    results.append(("Health Tools", _run_scored(test_health_tools)))

    # Summary
    print("\n" + "=" * 60)
    print("📊 Test Summary:")
    print("=" * 60)

    for test_name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"  {test_name}: {status}")

    total = len(results)
    passed = sum(1 for _, p in results if p)

    print(f"\n  Total: {passed}/{total} passed")

    if passed < total:
        print("\n⚠️  Some tests failed. This indicates an issue with:")
        print("     - Gemini API endpoint configuration")
        print("     - Tool schema format")
        print("     - Context/token limit")


if __name__ == "__main__":
    main()
