#!/usr/bin/env python3
"""
Bulk vault indexer for Second Brain RAG system.

Scans an Obsidian vault directory, splits all .md files into chunks by H2/H3
headers, embeds them, and upserts into ChromaDB.

Usage:
    python scripts/init_vault_indexer.py --path /path/to/vault [--reset] [--delay SECONDS]

Options:
    --path   Path to vault directory (or sub-directory) to index. Required.
    --reset  Drop and recreate the ChromaDB collection before indexing.
    --delay  Seconds to sleep between files to throttle embedding API calls
             (default: 0.5).

Examples:
    # Index the full vault
    python scripts/init_vault_indexer.py --path /root/vault/obsidian_vault/obsidian/obsidian

    # Index only daily notes, starting fresh
    python scripts/init_vault_indexer.py --path /root/vault/obsidian_vault/obsidian/obsidian/daily --reset
"""

import sys
import time
from collections.abc import Iterable, Iterator
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

# Add project root to sys.path so imports resolve correctly
sys.path.insert(0, str(Path(__file__).parent.parent))

import click

from health.utils.env_loader import load_env_with_extras
from health.utils.logging_config import setup_logger
from slack_bot.obsidian.embeddings import get_embedding_provider
from slack_bot.obsidian.vector_store import ChromaVectorStore, ChunkMetadata, split_md_by_headers

# Directory names whose contents are excluded from indexing
# (Obsidian config, attachments, trash, and tooling directories).
_SKIP_DIRS = {
    ".obsidian",
    "attachments",
    ".trash",
    "venv",
    "node_modules",
}

# Retry policy for embedding calls. _MAX_RETRIES is the TOTAL number of
# attempts (including the first); after failed attempt k the pause is
# _RETRY_BASE_DELAY * 2 ** (k - 1) seconds.
_MAX_RETRIES = 3
_RETRY_BASE_DELAY = 2.0  # seconds

# Module logger, also configured with log_file="init_vault_indexer.log".
logger = setup_logger(__name__, log_file="init_vault_indexer.log")


def _iter_md_files(root: Path):
    """Yield .md files under root, skipping excluded directories.

    Args:
        root: Root directory to scan.

    Yields:
        Path objects for each .md file found.
    """
    for p in root.rglob("*.md"):
        if any(skip in p.parts for skip in _SKIP_DIRS):
            continue
        yield p


def _embed_with_retry(store: ChromaVectorStore, texts: list[str], metadatas: list[ChunkMetadata]) -> None:
    """Add chunks to the store, backing off exponentially between failed attempts.

    At most ``_MAX_RETRIES`` attempts are made in total. After failed attempt
    ``k``, a warning is logged and the function sleeps
    ``_RETRY_BASE_DELAY * 2 ** (k - 1)`` seconds before the next attempt.

    Args:
        store: ChromaVectorStore instance that receives the chunks.
        texts: Chunk texts to embed and store.
        metadatas: Metadata entries, one per text.

    Raises:
        Exception: Whatever ``store.add_chunks()`` raised on the final attempt.
    """
    for attempt_no in range(1, _MAX_RETRIES + 1):
        try:
            store.add_chunks(texts, metadatas)
        except Exception as err:
            if attempt_no >= _MAX_RETRIES:
                # Out of attempts: let the caller see the last error.
                raise
            pause = _RETRY_BASE_DELAY * 2 ** (attempt_no - 1)
            logger.warning(f"Embed attempt {attempt_no} failed ({err}). Retrying in {pause:.1f}s...")
            time.sleep(pause)
        else:
            return


def _normalize_tags(raw: object) -> list[str]:
    """Coerce a front-matter ``tags`` value into a list of tag strings.

    A plain ``list(raw)`` is unsafe for two reasons. A scalar string
    (``tags: daily``) would be split into single characters. An empty
    ``tags:`` key (``None``) would raise ``TypeError``.

    Args:
        raw: The raw ``tags`` value from the note's front matter.

    Returns:
        Non-empty, whitespace-stripped tag strings, converted as follows:

        - A string is split on commas.
        - A list or tuple is stringified element-wise, dropping ``None``.
        - ``None`` gives ``[]``.
        - Any other value becomes a one-element list.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        items = raw.split(",")
    elif isinstance(raw, (list, tuple)):
        items = [str(item) for item in raw if item is not None]
    else:
        items = [str(raw)]
    return [item.strip() for item in items if item.strip()]


def _header_for_chunk(chunk: str) -> str:
    """Return the chunk's first line (max 100 chars) if it starts with ``#``, else ``"body"``."""
    if chunk.startswith("#"):
        return chunk.split("\n", 1)[0][:100]
    return "body"


def _build_chunk_metas(
    md_file: Path, chunks: list[str], note_title: str, tags: list[str]
) -> list[ChunkMetadata]:
    """Build one ChunkMetadata per chunk of ``md_file``, all sharing one ingest timestamp."""
    ingested_at = datetime.now(tz=timezone.utc).isoformat()
    return [
        ChunkMetadata(
            source_path=str(md_file),
            chunk_index=i,
            header_hierarchy=_header_for_chunk(chunk),
            note_title=note_title,
            tags=tags,
            ingested_at=ingested_at,
        )
        for i, chunk in enumerate(chunks)
    ]


@click.command()
@click.option(
    "--path",
    required=True,
    type=click.Path(exists=True),
    help="Vault directory (or sub-directory) to index.",
)
@click.option(
    "--reset",
    is_flag=True,
    default=False,
    help="Drop and recreate the ChromaDB collection before indexing.",
)
@click.option(
    "--delay",
    default=0.5,
    type=float,
    show_default=True,
    help="Seconds to sleep between files (throttles API rate).",
)
def main(path: str, reset: bool, delay: float) -> None:
    """Bulk-index Obsidian vault notes into ChromaDB.

    \f
    Args:
        path: Vault directory to scan.
        reset: Whether to reset the collection first.
        delay: Sleep time in seconds between each successfully indexed file,
            to avoid rate limiting.
    """
    # Resolve optional dependencies up front. Otherwise a missing package
    # would surface as every single file being "skipped" with an ImportError.
    try:
        from tqdm import tqdm
    except ImportError:
        raise SystemExit("tqdm is not installed. Run: pip install tqdm>=4.66.0") from None
    try:
        import frontmatter  # python-frontmatter
    except ImportError:
        raise SystemExit(
            "python-frontmatter is not installed. Run: pip install python-frontmatter"
        ) from None

    load_env_with_extras()

    start_time = datetime.now(tz=timezone.utc)
    logger.info("=" * 60)
    logger.info(f"init_vault_indexer started: {start_time.isoformat()}")
    logger.info(f"Vault path: {path}")
    logger.info(f"Reset: {reset}")
    logger.info("=" * 60)

    store = ChromaVectorStore(embedding_provider=get_embedding_provider())

    if reset:
        store.reset()
        logger.info("Collection reset complete")
        click.echo("Collection reset.")

    vault_root = Path(path)
    # Sorted so runs are reproducible; rglob order depends on the filesystem.
    md_files = sorted(_iter_md_files(vault_root))
    click.echo(f"Found {len(md_files)} .md files under {vault_root}")
    logger.info(f"Found {len(md_files)} .md files")

    skipped = 0
    indexed_files = 0
    total_chunks = 0

    for md_file in tqdm(md_files, desc="Indexing", unit="file"):
        try:
            post = frontmatter.load(str(md_file))
            meta = post.metadata

            # An empty `title:` key yields None; fall back to the file name
            # instead of indexing the literal string "None".
            note_title = str(meta.get("title") or md_file.stem)
            tags = _normalize_tags(meta.get("tags"))

            chunks = split_md_by_headers(post.content)
            if not chunks:
                logger.debug(f"No usable chunks in {md_file}, skipping")
                skipped += 1
                continue

            chunk_metas = _build_chunk_metas(md_file, chunks, note_title, tags)
            # NOTE(review): without --reset, re-running presumably duplicates
            # chunks unless add_chunks upserts by a stable ID — confirm in
            # ChromaVectorStore.
            _embed_with_retry(store, chunks, chunk_metas)
        except Exception as e:
            # Best-effort: one unreadable/unembeddable note must not abort the run.
            logger.warning(f"Skipped {md_file}: {e}")
            skipped += 1
            continue

        indexed_files += 1
        total_chunks += len(chunks)

        if delay > 0:
            time.sleep(delay)

    stats = store.get_stats()
    end_time = datetime.now(tz=timezone.utc)
    duration = (end_time - start_time).total_seconds()

    summary = (
        f"\n✅ Indexing complete in {duration:.1f}s\n"
        f"   Files indexed : {indexed_files}\n"
        f"   Files skipped : {skipped}\n"
        f"   Chunks added  : {total_chunks}\n"
        f"   Total in DB   : {stats['total_chunks']}\n"
        f"   DB path       : {stats['db_path']}"
    )
    click.echo(summary)
    logger.info(summary)


if __name__ == "__main__":
    main()
