"""ChromaDB vector store for the Second Brain RAG system."""

import hashlib
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel, Field

from health.utils.logging_config import setup_logger
from slack_bot.obsidian.embeddings import EmbeddingProvider, get_embedding_provider

# Module-level logger used by every function and method in this file.
logger = setup_logger(__name__)


class ChunkMetadata(BaseModel):
    """Metadata for a document chunk stored in the vector DB.

    ``source_path`` and ``chunk_index`` together identify a chunk; the vector
    store derives its document ID from them. ``tags`` is stored flattened to a
    comma-separated string because vector-DB metadata values must be scalars.
    """

    source_path: str  # path of the note the chunk was taken from
    chunk_index: int  # index of the chunk within its source note
    header_hierarchy: str  # header hierarchy string recorded for the chunk
    note_title: Optional[str] = None  # note title, if one is known
    tags: List[str] = Field(default_factory=list)  # note tags (may be empty)
    # ISO 8601 datetime string. When omitted, defaults to the current time in
    # UTC so callers no longer have to stamp every chunk themselves.
    ingested_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )


class SearchResult(BaseModel):
    """One hit produced by a semantic search over stored note chunks."""

    text: str  # chunk text exactly as stored
    source_path: str  # note the chunk belongs to
    header_hierarchy: str  # header hierarchy string recorded at ingest time
    distance: float  # distance from the query; lower means a closer match
    note_title: Optional[str] = Field(default=None)  # note title when available


def split_md_by_headers(body: str, min_chars: int = 50) -> List[str]:
    """Split a markdown body by H2/H3 headers into chunks.

    A new chunk starts at every line that begins with ``## `` or ``### ``
    (H1 and H4+ headers never split). Lines that merely look like headers
    inside fenced code blocks (opened with ``` or ~~~) are treated as code
    and do not split, so code samples are kept in one piece.

    Args:
        body: Markdown text to split.
        min_chars: Minimum length, after stripping surrounding whitespace, a
            chunk must have to be kept. Defaults to 50.

    Returns:
        Stripped chunks of at least ``min_chars`` characters, in document order.
    """
    sections: List[List[str]] = [[]]
    # Marker ("```" or "~~~") of the currently open code fence, else None.
    open_fence: Optional[str] = None
    for line in body.split("\n"):
        marker = line.lstrip()[:3]
        if marker in ("```", "~~~"):
            # Simple fence tracking: a fence closes only with the same marker.
            if open_fence is None:
                open_fence = marker
            elif marker == open_fence:
                open_fence = None
        elif open_fence is None and line.startswith(("## ", "### ")):
            sections.append([])
        sections[-1].append(line)

    chunks = ("\n".join(lines).strip() for lines in sections)
    return [chunk for chunk in chunks if len(chunk) >= min_chars]


class ChromaVectorStore:
    """ChromaDB-backed vector store for Obsidian notes.

    Chunks are embedded with an ``EmbeddingProvider`` and persisted in a single
    ChromaDB collection configured for cosine distance.
    """

    COLLECTION_NAME = "obsidian_notes"

    def __init__(
        self,
        embedding_provider: Optional[EmbeddingProvider] = None,
        db_path: Optional[Path] = None,
    ) -> None:
        """Initialize the ChromaDB vector store.

        Args:
            embedding_provider: Provider for generating embeddings. Defaults to
                get_embedding_provider().
            db_path: Path to ChromaDB persistence directory (a ``str`` is also
                accepted). Defaults to DATA_DIR/vector_db. Created if missing.

        Raises:
            ImportError: If the ``chromadb`` package cannot be imported.
        """
        if embedding_provider is None:
            embedding_provider = get_embedding_provider()
        self.embedding_provider = embedding_provider

        if db_path is None:
            # Local import: project config is only needed for the default path.
            from health import config
            db_path = config.DATA_DIR / "vector_db"
        # Normalize so plain string paths work as well as Path objects.
        db_path = Path(db_path)

        db_path.mkdir(parents=True, exist_ok=True)
        self.db_path = db_path

        # Guard only the import itself so that failures during client or
        # collection setup are not misreported as "chromadb is not installed".
        # The requirement is quoted so the hint can be pasted into a shell
        # without ">" being treated as an output redirect.
        try:
            import chromadb
        except ImportError as err:
            raise ImportError(
                'chromadb is not installed. Run: pip install "chromadb>=0.5.0"'
            ) from err

        self._client = chromadb.PersistentClient(path=str(db_path))
        self._collection = self._get_or_create_collection()
        logger.info(
            f"ChromaVectorStore initialized at {db_path} "
            f"(collection: {self.COLLECTION_NAME})"
        )

    def _get_or_create_collection(self):
        """Return the notes collection, creating it with cosine space if missing."""
        return self._client.get_or_create_collection(
            name=self.COLLECTION_NAME,
            metadata={"hnsw:space": "cosine"},
        )

    def _make_doc_id(self, source_path: str, chunk_index: int) -> str:
        """Generate a deterministic 16-char document ID for idempotent upsert.

        Args:
            source_path: Path to the source file.
            chunk_index: Index of the chunk within the file.

        Returns:
            16-character hex string (truncated SHA-256 of "path:index").
        """
        raw = f"{source_path}:{chunk_index}"
        return hashlib.sha256(raw.encode()).hexdigest()[:16]

    @staticmethod
    def _to_chroma_metadata(meta: ChunkMetadata) -> dict:
        """Flatten a ChunkMetadata into a ChromaDB-compatible metadata dict.

        ChromaDB metadata values must be str/int/float/bool, so the ``tags``
        list is joined with commas and keys whose value is ``None`` (e.g. an
        unknown ``note_title``) are omitted instead of being stored as None.
        """
        flat = meta.model_dump()
        flat["tags"] = ",".join(flat["tags"])
        return {key: value for key, value in flat.items() if value is not None}

    def add_chunks(self, texts: List[str], metadatas: List[ChunkMetadata]) -> None:
        """Add or update chunks in the vector store (idempotent upsert).

        Document IDs are derived from ``(source_path, chunk_index)``, so adding
        the same chunk again overwrites it. Chunks left over from an earlier
        ingest of the same source at indices no longer produced are NOT
        removed by this method.

        Args:
            texts: List of text chunks.
            metadatas: Corresponding ChunkMetadata for each chunk, in the same
                order and of the same length as ``texts``.

        Raises:
            ValueError: If ``texts`` and ``metadatas`` differ in length.
        """
        if len(texts) != len(metadatas):
            raise ValueError(
                "texts and metadatas must have the same length "
                f"(got {len(texts)} and {len(metadatas)})"
            )
        if not texts:
            return

        doc_ids = [self._make_doc_id(m.source_path, m.chunk_index) for m in metadatas]
        embeddings = self.embedding_provider.embed(texts)
        meta_dicts = [self._to_chroma_metadata(m) for m in metadatas]

        self._collection.upsert(
            ids=doc_ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=meta_dicts,
        )
        logger.debug(f"Upserted {len(texts)} chunks into vector store")

    def search_knowledge(self, query: str, top_k: int = 3) -> List[SearchResult]:
        """Semantic search over stored chunks.

        Args:
            query: Search query text.
            top_k: Maximum number of results to return.

        Returns:
            List of SearchResult ordered by relevance (lowest distance first).
            Empty when ``top_k`` is less than 1 or the collection is empty.
        """
        # ChromaDB rejects n_results < 1, so short-circuit instead of querying.
        if top_k <= 0:
            return []

        total = self._collection.count()
        if total == 0:
            return []

        query_embedding = self.embedding_provider.embed([query])[0]
        results = self._collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, total),
            include=["documents", "metadatas", "distances"],
        )

        search_results: List[SearchResult] = []
        if not results["ids"] or not results["ids"][0]:
            return search_results

        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            # Records written without metadata come back as None.
            meta = meta or {}
            search_results.append(
                SearchResult(
                    text=doc,
                    source_path=meta.get("source_path", ""),
                    header_hierarchy=meta.get("header_hierarchy", ""),
                    distance=dist,
                    note_title=meta.get("note_title"),
                )
            )

        return search_results

    def reset(self) -> None:
        """Delete and recreate the collection, removing all stored data."""
        self._client.delete_collection(self.COLLECTION_NAME)
        self._collection = self._get_or_create_collection()
        logger.info(f"Collection '{self.COLLECTION_NAME}' reset")

    def get_stats(self) -> dict:
        """Return basic collection statistics.

        Returns:
            Dict with collection_name, total_chunks, and db_path.
        """
        return {
            "collection_name": self.COLLECTION_NAME,
            "total_chunks": self._collection.count(),
            "db_path": str(self.db_path),
        }
