import os
import random
import re
from typing import List, Dict, Optional, Set
from health.utils.logging_config import setup_logger

# Module-level logger shared by everything in this file, created through the
# project's logging helper and named after this module.
logger = setup_logger(__name__)

class ObsidianIndexer:
    """
    Scans an Obsidian Vault to index files and extract samples based on tags.

    The index lives entirely in memory and is rebuilt from scratch on every
    call to ``scan_vault``. After a scan the instance exposes:

    * ``writing_samples`` / ``reply_samples``: paths of notes carrying the
      corresponding sample tag.
    * ``file_index``: path -> lowercased (and truncated) note text used by
      ``search``.
    * ``file_metadata``: path -> modification time (seconds since the epoch)
      used by ``get_recent_files``.
    * ``files_scanned``: number of ``.md`` files visited by the last scan,
      including files that could not be read.
    """

    WRITING_SAMPLE_TAG = "#writing_sample"
    REPLY_SAMPLE_TAG = "#reply_sample"

    # Upper bound on the characters of each note kept in the search index, so a
    # single huge file cannot blow up memory usage.
    MAX_INDEXED_CHARS = 100000

    # Query tokenizer for mixed English/Chinese input (e.g. "XLSmart遇到了"):
    #   1. runs of ASCII letters / digits / underscores become one token each
    #   2. every CJK character in U+4E00..U+9FFF becomes its own token
    _TOKEN_PATTERN = re.compile(r'[a-zA-Z0-9_]+|[\u4e00-\u9fff]')

    def __init__(self, vault_path: str):
        """
        Args:
            vault_path: Path to the vault directory; ``~`` is expanded. A
                falsy or non-directory path only logs a warning here, and
                ``scan_vault`` will later refuse to run.
        """
        self.vault_path = os.path.expanduser(vault_path) if vault_path else None
        if not self._vault_is_valid():
            logger.warning(f"Obsidian Vault path invalid: {self.vault_path}")

        self.writing_samples: List[str] = []
        self.reply_samples: List[str] = []
        self.file_index: Dict[str, str] = {}  # filepath -> lowercase content (for simple search)
        self.file_metadata: Dict[str, float] = {}  # filepath -> mtime
        self.files_scanned = 0

    def _vault_is_valid(self) -> bool:
        """Return True when ``vault_path`` is set and points at a directory."""
        # isdir rather than exists: a regular file would pass an existence
        # check but os.walk() over it silently yields nothing.
        return bool(self.vault_path) and os.path.isdir(self.vault_path)

    def scan_vault(self) -> None:
        """
        Scans the vault for markdown files and populates indexes.

        All previous index state is discarded first. Hidden directories
        (names starting with ``.``, such as ``.obsidian`` or ``.git``) are not
        descended into. Logs an error and returns without touching the
        existing indexes if the vault path is not a directory.
        """
        if not self._vault_is_valid():
            logger.error(f"Cannot scan: Invalid vault path {self.vault_path}")
            return

        logger.info(f"Scanning Obsidian Vault at: {self.vault_path}")
        self.writing_samples = []
        self.reply_samples = []
        self.file_index = {}
        self.file_metadata = {}
        self.files_scanned = 0

        for root, dirs, files in os.walk(self.vault_path):
            # In-place slice assignment is what tells os.walk to prune the
            # hidden folders instead of descending into them.
            dirs[:] = [d for d in dirs if not d.startswith('.')]

            for file in files:
                if file.endswith(".md"):
                    self._process_file(os.path.join(root, file))
                    self.files_scanned += 1

        logger.info(f"Scan complete. Scanned {self.files_scanned} files.")
        logger.info(f"Found {len(self.writing_samples)} writing samples.")
        logger.info(f"Found {len(self.reply_samples)} reply samples.")

    @staticmethod
    def _contains_tag(content: str, tag: str) -> bool:
        """
        Return True if ``tag`` occurs in ``content`` as a complete tag.

        The tag may appear anywhere in the text, but it must not be
        immediately followed by another word character or hyphen, so that
        ``#writing_samples`` is not mistaken for ``#writing_sample``. A nested
        tag such as ``#writing_sample/blog`` still counts.
        """
        return re.search(re.escape(tag) + r"(?![\w-])", content) is not None

    def _process_file(self, file_path: str) -> None:
        """
        Reads a file, records its mtime, checks for tags and indexes its text.

        Read failures are logged and otherwise ignored so that one bad file
        does not abort the whole scan. The mtime is recorded before the file
        is read, so an unreadable file can still have a metadata entry.
        """
        try:
            self.file_metadata[file_path] = os.path.getmtime(file_path)

            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError for non-UTF-8 files.
            logger.warning(f"Failed to read file {file_path}: {e}")
            return

        # Basic tag detection on the raw text; no frontmatter parser is used
        # because Obsidian tags can be anywhere in the note.
        if self._contains_tag(content, self.WRITING_SAMPLE_TAG):
            self.writing_samples.append(file_path)

        if self._contains_tag(content, self.REPLY_SAMPLE_TAG):
            self.reply_samples.append(file_path)

        # Naive in-memory search index: truncated, lowercased text.
        self.file_index[file_path] = content[:self.MAX_INDEXED_CHARS].lower()

    def _pick_samples(self, pool: List[str], count: int) -> List[str]:
        """Read up to ``count`` randomly chosen files from ``pool``."""
        # Clamp to [0, len(pool)]: random.sample raises on a negative or
        # too-large sample size.
        sample_size = max(0, min(len(pool), count))
        if sample_size == 0:
            return []
        return self._read_files(random.sample(pool, sample_size))

    def get_writing_samples(self, count: int = 3) -> List[str]:
        """
        Returns content of random writing samples.

        Args:
            count: Maximum number of samples; values <= 0 yield an empty list.

        Returns:
            At most ``count`` note texts, each prefixed with a source header.
        """
        return self._pick_samples(self.writing_samples, count)

    def get_reply_samples(self, count: int = 3) -> List[str]:
        """
        Returns content of random reply samples.

        Args:
            count: Maximum number of samples; values <= 0 yield an empty list.

        Returns:
            At most ``count`` note texts, each prefixed with a source header.
        """
        return self._pick_samples(self.reply_samples, count)

    def search(self, query: str, limit: int = 5) -> List[str]:
        """
        Simple keyword search. Returns content of matching files.
        RAG Strategy: Find notes containing the query keywords.

        Scoring per file: every distinct query token contributes
        ``occurrences * len(token) ** 3`` (filename occurrences count twice),
        and the sum is multiplied by the square of the number of distinct
        tokens that matched at all.

        Args:
            query: Free-text query; matching is case-insensitive.
            limit: Maximum number of files to return; values <= 0 yield an
                empty list.

        Returns:
            Contents of the best-scoring files, highest score first. May hold
            fewer than ``limit`` items if fewer files match or a file can no
            longer be read.
        """
        if not query or limit <= 0:
            return []

        # De-duplicate while keeping order: a token repeated in the query must
        # not be counted twice nor inflate the distinct-match boost.
        tokens = list(dict.fromkeys(self._TOKEN_PATTERN.findall(query.lower())))
        if not tokens:
            return []

        logger.debug(f"Search tokens: {tokens}")

        scored_files = []
        for path, content in self.file_index.items():
            filename = os.path.basename(path).lower()
            score = 0
            unique_matches = 0

            for token in tokens:
                # Filename hits are weighted double. Tokens never contain
                # whitespace, so counting the two texts separately equals
                # counting in "filename filename content" without building
                # that large temporary string for every file.
                occurrences = 2 * filename.count(token) + content.count(token)
                if occurrences:
                    # Cubic weighting by length heavily favors longer keywords
                    # (like "XLSmart") over common single characters.
                    score += occurrences * (len(token) ** 3)
                    unique_matches += 1

            # Boost score based on how many distinct tokens matched.
            if unique_matches > 0:
                scored_files.append((score * (unique_matches ** 2), path))

        # Sort by score desc; the stable sort keeps index order among ties.
        scored_files.sort(key=lambda x: x[0], reverse=True)

        top_paths = [path for _, path in scored_files[:limit]]
        return self._read_files(top_paths)

    def get_recent_files(self, days: int = 5, limit: int = 10) -> str:
        """
        Returns a formatted list of files modified in the last N days.

        Uses the modification times captured by the last scan, not the
        current state of the filesystem.

        Args:
            days: Size of the look-back window in days.
            limit: Maximum number of files to list; values <= 0 list nothing.

        Returns:
            A markdown bullet list (newest first, local time), or a short
            message when no file qualifies.
        """
        import time

        cutoff_time = time.time() - (days * 86400)

        recent_files = [
            (mtime, path)
            for path, mtime in self.file_metadata.items()
            if mtime >= cutoff_time
        ]

        # Sort by mtime DESC (newest first)
        recent_files.sort(key=lambda x: x[0], reverse=True)

        # Clamp so a negative limit cannot turn into "all but the last N".
        recent_files = recent_files[:max(limit, 0)]

        if not recent_files:
            return f"No files modified in the last {days} days."

        output = [f"📂 **Updated in last {days} days:**"]
        for mtime, path in recent_files:
            filename = os.path.basename(path)
            date_str = time.strftime('%Y-%m-%d %H:%M', time.localtime(mtime))
            output.append(f"- `{filename}` ({date_str})")

        return "\n".join(output)

    def _read_files(self, paths: List[str]) -> List[str]:
        """
        Helper to read complete content of list of paths.

        Each returned item starts with a ``--- Source: <filename> ---`` header
        line followed by the file text. Files that cannot be read are logged
        and skipped, so the result may be shorter than ``paths``.
        """
        contents = []
        for p in paths:
            try:
                with open(p, "r", encoding="utf-8") as f:
                    text = f.read()
            except (OSError, ValueError) as e:
                # Best effort: the file may have been moved or deleted since
                # the scan. Report it instead of failing silently.
                logger.warning(f"Failed to read file {p}: {e}")
                continue
            # Provide context about the source
            contents.append(f"--- Source: {os.path.basename(p)} ---\n{text}")
        return contents
