from typing import Optional, List, Dict, Any
from slack_sdk import WebClient
import os

from health.utils.logging_config import setup_logger
from slack_bot.llm.gemini import GeminiLLM
from slack_bot.context.storage import ContextStorage
from slack_bot.tools.registry import TOOLS_SCHEMA, TOOL_FUNCTIONS
from slack_bot.tools.groups import get_tool_preset

logger = setup_logger(__name__)


class MessageDispatcher:
    """Routes incoming Slack messages to Gemini and executes tools."""

    def __init__(self, bot_token: Optional[str] = None, system_instruction: Optional[str] = None, tools: Optional[List[Dict]] = None, tool_mode: str = "light"):
        """
        Initialize dispatcher.
        Args:
            bot_token: Optional specific Slack token
            system_instruction: Custom system prompt (e.g. for Shell Bot)
            tools: specific tools list to use (if None, uses preset based on tool_mode)
            tool_mode: 'none' (no tools), 'light' (5 tools), 'standard' (10 tools), 'full' (15 tools), 'all' (21 tools)
        """
        self.llm = GeminiLLM(system_instruction=system_instruction)

        # Use tool preset if tools not explicitly provided
        if tools is not None:
            self.tools = tools
            self.tool_functions = TOOL_FUNCTIONS
        elif tool_mode == "none":
            self.tools = None
            self.tool_functions = {}
            logger.info("Running in NO-TOOLS mode (pure text)")
        else:
            self.tools, self.tool_functions = get_tool_preset(tool_mode)
            logger.info(f"Using tool preset: {tool_mode} ({len(self.tools)} tools)")

        # Initialize Slack WebClient for posting replies
        self.bot_token = bot_token or os.environ.get("SLACK_BOT_TOKEN")
        self.client = WebClient(token=self.bot_token)

        logger.info("Dispatcher initialized with Gemini")

    def dispatch(self, message_text: str, channel_id: str, user_id: str, response_ts: Optional[str] = None, request_id: str = "N/A", files: Optional[List[Dict]] = None) -> None:
        """
        Process a message and trigger responses.
        """
        import time
        t0 = time.time()
        
        prefix = f"[{request_id}]"
        logger.info(f"{prefix} Processing message in {channel_id} from {user_id}")
        if files:
            logger.info(f"{prefix} Message contains {len(files)} files")

        # 1. Initialize Context
        storage = ContextStorage(channel_id)
        
        # Handle "clear" command
        if message_text.strip().lower() in ["clear", "reset", "清除"]:
            storage.clear()
            self.client.chat_postMessage(channel=channel_id, text="🧹 Context cleared. (对话历史已清除)")
            return

        storage.add_message("user", message_text)

        # Download images if any
        image_data_list = []
        download_failed = False
        if files:
            image_data_list = self._download_files(files, prefix)
            # Anti-Hallucination: Check if download failed
            if len(image_data_list) == 0:
                failure_msg = "\n[SYSTEM WARNING: User uploaded an image but it FAILED to download/process. You CANNOT see the image. DO NOT GUESS. Inform the user that the image access failed.]"
                logger.warning(f"{prefix} Image download failed! Appending warning to LLM.")
                message_text += failure_msg
                download_failed = True

        try:
            # 2. Get conversation context
            context = storage.get_context()
            logger.debug(f"{prefix} Context size: {len(context)}")

            # 3. Process with Gemini (with tools)
            # CRITICAL: If image download failed, disable tools to prevent "400 INVALID_ARGUMENT"
            # (which happens when mixing text-only warning prompts with complex tool schemas in some proxies)
            #
            # NOTE: gemini-3-flash has a bug where it returns empty response with tools,
            # but we use SAFETY OVERRIDE below to manually trigger tools based on keywords
            current_tools = None if download_failed else self.tools

            logger.info(f"{prefix} Calling Gemini for initial response...")
            logger.info(f"{prefix} Request config: model={self.llm.get_model_name()}, tools={len(current_tools) if current_tools else 0}, context={len(context)-1} msgs")
            t_llm_start = time.time()
            response_text, tool_calls = self.llm.generate_response(
                message=message_text,
                context=context[:-1],
                tools=current_tools,
                images=image_data_list
            )
            # Defensive check for None/empty response
            if response_text is None:
                logger.error(f"{prefix} Gemini returned None for response_text!")
                response_text = ""
            elif not response_text.strip():
                logger.warning(f"{prefix} Gemini returned empty response_text! (This usually means the proxy has issues)")

            # Initialize tool_calls
            if tool_calls is None:
                tool_calls = []

            # Log what we got back
            logger.info(f"{prefix} Gemini response: text={len(response_text)} chars, tool_calls={len(tool_calls) if tool_calls else 0}")

            # --- SAFETY OVERRIDE START ---
            # If the LLM failed to trigger tools (common with local proxies + tools bug),
            # manually detect intent and force tool calls based on keywords.
            # This is a FALLBACK mechanism when Gemini returns empty response or MALFORMED_FUNCTION_CALL.
            if not tool_calls:
                logger.info(f"{prefix} No tool calls from Gemini - checking SAFETY OVERRIDE keywords...")
                lower_msg = message_text.lower()

                # 1. Sync Garmin data
                if any(k in lower_msg for k in ["sync", "update", "fetch", "同步", "拉取", "更新"]) and \
                   any(k in lower_msg for k in ["garmin", "佳明", "data", "数据"]):
                    logger.warning(f"{prefix} 🔧 SAFETY OVERRIDE: Forcing sync_garmin tool")
                    tool_calls.append({
                        "name": "sync_garmin",
                        "args": {"target_date": None}
                    })

                # 2. Query health data (sleep, steps, activities, etc.)
                elif any(k in lower_msg for k in ["睡眠", "sleep", "步数", "steps", "心率", "heart", "hrv", "运动", "activity", "workout", "锻炼", "昨天", "yesterday", "今天", "today", "上午", "下午", "morning", "afternoon", "查询", "query", "变化", "趋势", "trend"]):
                    from datetime import datetime, timedelta

                    # CRITICAL: First check for single-day keywords to avoid false positives
                    has_single_day_keyword = any(k in lower_msg for k in [
                        "昨天", "昨晚", "昨日", "last night", "yesterday",
                        "今天", "今晚", "今日", "today", "tonight",
                        "前天", "前晚", "day before yesterday",
                        "大前天", "three days ago"
                    ])

                    # Check if asking for time-range trends (past X days/weeks/months)
                    is_trend_query = (not has_single_day_keyword) and any(k in lower_msg for k in [
                        "过去", "最近", "近", "past", "recent",
                        "变化", "趋势", "trend", "change", "历史", "history",
                        "个月", "month", "周", "week", "年", "year",
                    ])

                    if is_trend_query:
                        # Parse time range from query
                        import re
                        today = datetime.now()
                        end_date = today.strftime("%Y-%m-%d")
                        days_ago = 30  # default

                        if re.search(r'(\d+)\s*个?月|(\d+)\s*month', lower_msg):
                            months = int(re.search(r'(\d+)\s*个?月|(\d+)\s*month', lower_msg).group(1) or re.search(r'(\d+)\s*个?月|(\d+)\s*month', lower_msg).group(2))
                            days_ago = months * 30
                        elif re.search(r'(\d+)\s*周|(\d+)\s*week', lower_msg):
                            weeks = int(re.search(r'(\d+)\s*周|(\d+)\s*week', lower_msg).group(1) or re.search(r'(\d+)\s*周|(\d+)\s*week', lower_msg).group(2))
                            days_ago = weeks * 7
                        elif re.search(r'(\d+)\s*天|(\d+)\s*day', lower_msg):
                            days_ago = int(re.search(r'(\d+)\s*天|(\d+)\s*day', lower_msg).group(1) or re.search(r'(\d+)\s*天|(\d+)\s*day', lower_msg).group(2))
                        elif re.search(r'(\d+)\s*年|(\d+)\s*year', lower_msg):
                            years = int(re.search(r'(\d+)\s*年|(\d+)\s*year', lower_msg).group(1) or re.search(r'(\d+)\s*年|(\d+)\s*year', lower_msg).group(2))
                            days_ago = years * 365
                        elif "半年" in lower_msg or "six month" in lower_msg:
                            days_ago = 180
                        elif "三个月" in lower_msg or "three month" in lower_msg:
                            days_ago = 90

                        start_date = (today - timedelta(days=days_ago)).strftime("%Y-%m-%d")

                        # Determine metric type
                        metric_type = "rhr"  # default
                        if any(k in lower_msg for k in ["hrv", "心率变异"]):
                            metric_type = "hrv"
                        elif any(k in lower_msg for k in ["睡眠", "sleep"]):
                            metric_type = "sleep"
                        elif any(k in lower_msg for k in ["步数", "steps"]):
                            metric_type = "steps"
                        elif any(k in lower_msg for k in ["压力", "stress"]):
                            metric_type = "stress"

                        logger.warning(f"{prefix} 🔧 SAFETY OVERRIDE: Forcing get_metric_history for {metric_type} ({start_date} to {end_date})")
                        tool_calls.append({
                            "name": "get_metric_history",
                            "args": {
                                "metric_type": metric_type,
                                "start_date": start_date,
                                "end_date": end_date
                            }
                        })
                    else:
                        # Single-day query
                        if any(k in lower_msg for k in ["昨天", "昨晚", "昨日", "last night", "yesterday"]):
                            target = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
                        elif any(k in lower_msg for k in ["前天", "前晚", "day before yesterday"]):
                            target = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d")
                        else:
                            target = datetime.now().strftime("%Y-%m-%d")

                        # Check if specifically asking for activity history
                        if any(k in lower_msg for k in ["运动", "activity", "workout", "锻炼"]):
                            logger.warning(f"{prefix} 🔧 SAFETY OVERRIDE: Forcing get_activity_history for {target}")
                            tool_calls.append({
                                "name": "get_activity_history",
                                "args": {"start_date": target, "end_date": target}
                            })
                        else:
                            logger.warning(f"{prefix} 🔧 SAFETY OVERRIDE: Forcing get_daily_detailed_stats for {target}")
                            tool_calls.append({
                                "name": "get_daily_detailed_stats",
                                "args": {"target_date": target}
                            })

                # 3. Web search
                elif any(k in lower_msg for k in ["搜索", "search", "查一下", "look up", "最新", "latest"]):
                    query = message_text
                    logger.warning(f"{prefix} 🔧 SAFETY OVERRIDE: Forcing search_web")
                    tool_calls.append({
                        "name": "search_web",
                        "args": {"query": query}
                    })

                # 4. Detect diet logging intent (for logging, not forcing tool call yet)
                food_items = ["鸡", "肉", "菜", "蛋", "饭", "面", "粉", "汤", "鱼", "虾", "蟹", "牛", "猪", "羊", "豆", "奶"]
                meal_indicators = ["晚餐", "午餐", "早餐", "夜宵", "零食", "breakfast", "lunch", "dinner", "snack"]
                has_ate_pattern = any(k in lower_msg for k in ["吃了", "ate", "吃的", "eating"])
                has_meal_indicator = any(k in lower_msg for k in meal_indicators)
                has_food_item = any(k in lower_msg for k in food_items)

                if (has_ate_pattern and has_food_item) or has_meal_indicator:
                    logger.warning(f"{prefix} Detected diet logging intent: ate={has_ate_pattern}, meal={has_meal_indicator}, food={has_food_item}")
                    # Don't force tool - let user confirm first or LLM will handle in second round
                else:
                    logger.info(f"{prefix} No SAFETY OVERRIDE match found - will return empty response or text")
            else:
                logger.info(f"{prefix} Tool calls successfully generated by Gemini (no SAFETY OVERRIDE needed)")
            # --- SAFETY OVERRIDE END ---

            tool_count = len(tool_calls) if tool_calls else 0
            logger.info(f"{prefix} Gemini response in {time.time()-t_llm_start:.2f}s: {len(response_text)} chars, {tool_count} tools")

            # 4. Execute tools if requested
            tool_results = []
            if tool_calls:
                logger.info(f"{prefix} Executing {len(tool_calls)} tools")
                
                for tool_call in tool_calls:
                    tool_name = tool_call["name"]
                    tool_args = tool_call["args"]
                    logger.info(f"{prefix} Tool: {tool_name} Args: {tool_args}")
                    
                    t_tool_start = time.time()
                    if tool_name in self.tool_functions:
                        try:
                            # Inject channel_id for shell tool
                            if tool_name == "execute_shell":
                                tool_args["channel_id"] = channel_id

                            result = self.tool_functions[tool_name](**tool_args)
                            
                            # Safety Truncation: Prevent massive JSONs from blowing up context
                            str_result = str(result)
                            if len(str_result) > 8000:
                                str_result = str_result[:8000] + "... (truncated)"
                            
                            tool_results.append({
                                "tool": tool_name,
                                "args": tool_args,
                                "result": str_result
                            })
                            logger.info(f"{prefix} ✓ {tool_name} finished in {time.time()-t_tool_start:.2f}s")
                        except Exception as e:
                            error_msg = f"Error executing {tool_name}: {str(e)}"
                            logger.error(f"{prefix} {error_msg}")
                            tool_results.append({
                                "tool": tool_name,
                                "args": tool_args,
                                "result": f"❌ {error_msg}"
                            })
                    else:
                        logger.warning(f"{prefix} Unknown tool: {tool_name}")

                # 4b. Second round: Get final response with tool results
                logger.info(f"{prefix} Requesting final analysis...")

                tool_names = ", ".join([tr['tool'] for tr in tool_results])
                
                # Format specific tool results for context
                tool_results_text = "\n".join([
                    f"Tool '{tr['tool']}' (Args: {tr['args']}) returned:\n{tr['result']}"
                    for tr in tool_results
                ])

                # Inject thoughts into memory
                storage.add_message("assistant", f"I checked: {tool_names}", model="gemini")
                
                # Stronger Prompt for Analysis
                analysis_prompt = (
                    f"Here are the execution results from the tools:\n{tool_results_text}\n\n"
                    f"CRITICAL INSTRUCTION: You represent the 'Butler' bot. You must now answer the user's original question based ONLY on the above results.\n"
                    f"1. Summarize the key data points found (e.g., steps, sleep score).\n"
                    f"2. Provide direct, natural language insights.\n"
                    f"3. If the tool indicated an error, explain it to the user.\n\n"
                    f"Do NOT output raw JSON. Write a helpful paragraph."
                )
                
                # Do NOT add this massive data blob to persistent storage.
                # Just use it for this generation context.
                final_context = storage.get_context()
                
                t_llm_2_start = time.time()
                response_text, _ = self.llm.generate_response(
                    message=analysis_prompt,
                    context=final_context,
                    tools=None
                )
                logger.info(f"{prefix} Final analysis received in {time.time()-t_llm_2_start:.2f}s")

                # Fallback
                if not response_text or not response_text.strip():
                    logger.warning(f"{prefix} Empty analysis received!")
                    if files:
                        response_text = "⚠️ 图片已接收，但 AI 模型未能返回描述。这可能是由于网络波动或模型暂时无法识别该图像，请稍后重试。"
                    else:
                        response_text = "✅ 数据已同步，但我似乎暂时无法生成文字分析。请参考上方的工具执行结果。"

            # 5. Format and send response
            # Relaxing internal limit to allow for multi-message splitting (max 3 messages of ~4000 chars)
            APP_LIMIT = 12000 
            full_response = self._format_response(response_text, tool_results, max_length=APP_LIMIT)

            # Split into chunks
            chunks = self._split_text(full_response, max_chunk_size=3800)
            
            # Limit to 3 chunks as requested
            if len(chunks) > 3:
                chunks = chunks[:3]
                chunks[-1] += "\n...(remaining content truncated)"

            # Post to Slack
            if response_ts:
                # Update the first message (Thinking...)
                try:
                    self.client.chat_update(channel=channel_id, ts=response_ts, text=chunks[0])
                    logger.info(f"{prefix} 📤 Message updated {response_ts}")
                except Exception as e:
                    logger.warning(f"{prefix} Update failed, posting new: {e}")
                    self.client.chat_postMessage(channel=channel_id, text=chunks[0])
                
                # Post subsequent chunks as new messages
                for chunk in chunks[1:]:
                    # Small delay to ensure order
                    time.sleep(0.5)
                    self.client.chat_postMessage(channel=channel_id, text=chunk)
                    logger.info(f"{prefix} 📤 Continuation message posted")

            else:
                # Post all as new messages
                for i, chunk in enumerate(chunks):
                    if i > 0: time.sleep(0.5)
                    self.client.chat_postMessage(channel=channel_id, text=chunk)
                logger.info(f"{prefix} 📤 {len(chunks)} message(s) posted")


            # 6. Save to Context
            storage.add_message("assistant", full_response, model="gemini")
            logger.info(f"{prefix} Dispatch completed in {time.time()-t0:.2f}s")

        except Exception as e:
            logger.error(f"{prefix} Dispatch failed: {e}", exc_info=True)
            try:
                self.client.chat_postMessage(
                    channel=channel_id,
                    text=f"⚠️ Internal Error: {str(e)}"
                )
                logger.info("Error message posted to Slack")
            except Exception as slack_error:
                logger.error(f"Failed to post error message to Slack: {slack_error}", exc_info=True)

    def _split_text(self, text: str, max_chunk_size: int = 3800) -> List[str]:
        """Split text into chunks avoiding breaking code blocks if possible."""
        if len(text) <= max_chunk_size:
            return [text]
            
        chunks = []
        current_chunk = ""
        lines = text.split('\n')
        
        in_code_block = False
        
        for line in lines:
            # Check for code block toggle
            if line.strip().startswith('```'):
                in_code_block = not in_code_block
                
            # If adding this line exceeds chunk size
            if len(current_chunk) + len(line) + 1 > max_chunk_size:
                # Close code block if open
                if in_code_block:
                    current_chunk += "\n```"
                
                chunks.append(current_chunk)
                
                # Start new chunk
                current_chunk = ""
                # Re-open code block if it was open
                if in_code_block:
                    current_chunk += "```\n" + "(...continued)\n"
            
            if current_chunk:
                current_chunk += "\n"
            current_chunk += line
            
        if current_chunk:
            chunks.append(current_chunk)
            
        return chunks

    def _format_response(self, text: str, tool_results: list, max_length: int = 39900) -> str:
        """
        Format the response with tool execution results.

        Args:
            text: The main response text
            tool_results: List of tool execution results
            max_length: Maximum allowed message length (default for new messages)

        Returns:
            Formatted response string, guaranteed to be under max_length
        """
        parts = []

        # Add tool execution results first (with smart truncation)
        if tool_results:
            parts.append("🛠️ *Tool Executions:*")
            for tr in tool_results:
                # Format args nicely
                args_str = ", ".join(f"{k}={v}" for k, v in tr['args'].items())

                # Truncate individual tool result if needed (more aggressive for multiple tools)
                result_str = str(tr['result'])
                # Allow larger tool outputs since we have multi-message support
                max_per_tool = 5000 
                if len(result_str) > max_per_tool:
                    result_str = result_str[:max_per_tool] + "... (truncated)"

                parts.append(f"• `{tr['tool']}({args_str})`: {result_str}")
            parts.append("")  # Blank line

        # Add text response with formatting
        if text:
            from slack_bot.utils.mrkdwn import SlackFormatter
            formatted_text = SlackFormatter.convert(text)
            logger.info(f"Formatted text (len {len(text)} -> {len(formatted_text)})")
            parts.append(formatted_text)

        # Join all parts
        if parts:
            full_response = "\n".join(parts)
        else:
            # No tools and no response - this is unusual
            logger.warning("No response text and no tool results - returning error message")

            # Enhanced debugging info for empty Gemini responses
            debug_info = f"\n\n🔍 调试信息:\n- 响应文本长度: {len(text)} 字符\n- 工具调用数: {len(tool_results)}\n- 模型: {self.llm.get_model_name()}"

            full_response = (
                "⚠️ 抱歉，我似乎没有理解你的请求。请尝试更具体的描述，或者直接说明你想查询什么数据（例如：今天的步数、昨天的睡眠、上午的运动等）。"
                + debug_info
            )

        # Final safety check: ensure we're under the limit
        if len(full_response) > max_length:
            truncation_msg = "\n\n⚠️ (Response truncated due to Slack length limit)"
            allowed_length = max_length - len(truncation_msg)
            full_response = full_response[:allowed_length] + truncation_msg
            logger.warning(f"Final response truncated from {len(full_response)} to {max_length} chars")

        return full_response

    def _download_files(self, files: List[Dict], prefix: str) -> List[Dict[str, Any]]:
        """Download files from Slack."""
        import requests
        
        token = self.bot_token
        from PIL import Image
        import io
        
        downloaded_files = []
        
        for file_info in files:
            logger.info(f"{prefix} found file: {file_info.get('name')} Type: {file_info.get('mimetype')} URL_P: {bool(file_info.get('url_private'))}")
            # Check if it's an image
            if 'mimetype' in file_info and file_info['mimetype'].startswith('image/'):
                url = file_info.get('url_private_download') or file_info.get('url_private')
                if not url:
                    continue
                    
                headers = {'Authorization': f'Bearer {token}'}
                # Debug Auth (Masked)
                token_masked = f"{token[:5]}...{token[-5:]}" if token else "None"
                logger.info(f"{prefix} Attempting download from {url} with token {token_masked}")
                
                # Try using url_private with token in query param
                # This often bypasses redirect auth stripping issues
                base_url = file_info.get('url_private')
                if not base_url:
                    continue
                
                # Use Session to handle cookies/redirects better
                session = requests.Session()
                session.headers.update({'Authorization': f'Bearer {token}'})
                
                # Debug Auth (Masked)
                token_masked = f"{token[:5]}...{token[-5:]}" if token else "None"
                logger.info(f"{prefix} Attempting download from {url} using Session (Token: {token_masked})")
                
                try:
                    response = session.get(url, timeout=10)
                    
                    # Log final URL to check for login page redirect
                    if response.url != url:
                        logger.info(f"{prefix} Redirected to: {response.url}")

                    if response.status_code == 200:
                        # DEBUG: Check for HTML (Login page)
                        ct = response.headers.get('Content-Type', 'unknown')
                        if 'html' in ct.lower():
                            logger.error(f"{prefix} Download returned HTML (likely Login Page) instead of image. Auth failed.")
                            continue

                        # Process image: Resize and Compress
                        try:
                            img = Image.open(io.BytesIO(response.content))
                        except Exception as img_err:
                            logger.error(f"{prefix} PIL Open Failed! Content preamble: {response.content[:200]!r}")
                            continue # Skip this file
                        
                        # Convert to RGB if necessary (e.g., PNG with alpha)
                        if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
                            bg = Image.new('RGB', img.size, (255, 255, 255))
                            if img.mode != 'RGBA':
                                img = img.convert('RGBA')
                            bg.paste(img, mask=img.split()[3])
                            img = bg
                        elif img.mode != 'RGB':
                            img = img.convert('RGB')
                            
                        # Resize if too large (max dimension 1024)
                        max_dim = 1024
                        if max(img.size) > max_dim:
                            img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
                        
                        # Compress to JPEG
                        buf = io.BytesIO()
                        img.save(buf, format='JPEG', quality=85)
                        optimized_data = buf.getvalue()
                        
                        logger.info(f"{prefix} Downloaded and optimized image: {file_info.get('name')} "
                                  f"({len(response.content)/1024:.1f}KB -> {len(optimized_data)/1024:.1f}KB)")
                        
                        downloaded_files.append({
                            "mime_type": "image/jpeg", # Always normalize to JPEG for consistency
                            "data": optimized_data
                        })
                    else:
                        logger.error(f"{prefix} Failed to download file: status {response.status_code}")
                except Exception as e:
                    logger.error(f"{prefix} Error downloading file: {e}")
            
        return downloaded_files
