"""Embedding providers for the Second Brain RAG system."""

import os
from abc import ABC, abstractmethod
from typing import List

from health.utils.logging_config import setup_logger

logger = setup_logger(__name__)


class EmbeddingProvider(ABC):
    """Interface that every embedding backend must implement."""

    @abstractmethod
    def embed(self, texts: List[str]) -> List[List[float]]:
        """Compute one embedding vector per input text.

        Args:
            texts: Texts to convert into vectors.

        Returns:
            One embedding vector for each entry in ``texts``, in order.
        """
        ...

    @abstractmethod
    def get_dimension(self) -> int:
        """Report the width of the vectors produced by ``embed``.

        Returns:
            Number of components in each embedding vector.
        """
        ...


class QwenEmbeddingProvider(EmbeddingProvider):
    """Cloud embedding provider using Qwen via an OpenAI-compatible API.

    Configuration (environment variables):
        EMBEDDING_API_KEY: Required API key for the embedding service.
        EMBEDDING_BASE_URL: API root URL; defaults to OpenRouter's
            OpenAI-compatible endpoint (note the required ``/v1`` suffix).
        EMBEDDING_MODEL: Model identifier; defaults to qwen/qwen3-embedding-8b.
    """

    def __init__(self) -> None:
        """Read configuration from the environment and build the API client.

        Raises:
            ValueError: If EMBEDDING_API_KEY is not set.
            ImportError: If the ``openai`` package is not installed.
        """
        api_key = os.environ.get("EMBEDDING_API_KEY")
        # BUG FIX: OpenRouter's OpenAI-compatible root is /api/v1. Without the
        # /v1 suffix the OpenAI client would POST to /api/embeddings and 404.
        base_url = os.environ.get("EMBEDDING_BASE_URL", "https://openrouter.ai/api/v1")
        self.model = os.environ.get("EMBEDDING_MODEL", "qwen/qwen3-embedding-8b")

        # Fail fast on missing credentials before importing the client library.
        if not api_key:
            raise ValueError("EMBEDDING_API_KEY is required for cloud embedding backend")

        try:
            from openai import OpenAI
        except ImportError as err:
            # Actionable message, consistent with SentenceTransformerProvider.
            raise ImportError(
                "openai is not installed.\n"
                "Run: pip install openai\n"
                "Or set EMBEDDING_BACKEND=local to use sentence-transformers instead."
            ) from err
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed texts via the remote OpenAI-compatible embeddings endpoint.

        Args:
            texts: List of text strings to embed.

        Returns:
            List of embedding vectors, one per input text, in input order.
        """
        if not texts:
            # Avoid a pointless (and possibly rejected) empty API request.
            return []
        response = self.client.embeddings.create(input=texts, model=self.model)
        return [item.embedding for item in response.data]

    def get_dimension(self) -> int:
        """Return the embedding dimension for the Qwen embedding model.

        Returns:
            Embedding dimension (4096 for qwen3-embedding-8b).
        """
        # NOTE(review): hard-coded for the default model; if EMBEDDING_MODEL is
        # overridden with a model of a different width this value will be wrong.
        return 4096


class SentenceTransformerProvider(EmbeddingProvider):
    """Local embedding provider using sentence-transformers (lazy-loaded).

    Configuration (environment variables):
        EMBEDDING_LOCAL_MODEL: Model name; defaults to all-MiniLM-L6-v2.
    """

    def __init__(self) -> None:
        self.model_name = os.environ.get("EMBEDDING_LOCAL_MODEL", "all-MiniLM-L6-v2")
        self._model = None  # lazy-load on first use

    def _get_model(self):
        """Lazy-load and cache the SentenceTransformer model.

        Raises:
            ImportError: If sentence-transformers is not installed.
        """
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as err:
                # Chain the cause so the original traceback is preserved.
                raise ImportError(
                    "sentence-transformers is not installed.\n"
                    "Run: pip install sentence-transformers\n"
                    "Or set EMBEDDING_BACKEND=cloud to use a remote API instead."
                ) from err
            logger.info(f"Loading local embedding model: {self.model_name}")
            self._model = SentenceTransformer(self.model_name)
        return self._model

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed texts using a local SentenceTransformer model.

        Args:
            texts: List of text strings to embed.

        Returns:
            List of embedding vectors, one per input text, in input order.
        """
        if not texts:
            # Skip model loading entirely for an empty batch.
            return []
        model = self._get_model()
        embeddings = model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()

    def get_dimension(self) -> int:
        """Return the embedding dimension of the configured model.

        Loads the model on first call: the true width depends on
        EMBEDDING_LOCAL_MODEL (e.g. 384 for all-MiniLM-L6-v2, 768 for
        all-mpnet-base-v2), so the previous hard-coded 384 was wrong for
        any non-default model.

        Returns:
            Embedding dimension reported by the loaded model.
        """
        return int(self._get_model().get_sentence_embedding_dimension())


def get_embedding_provider() -> EmbeddingProvider:
    """Factory: return the configured embedding provider.

    Reads the EMBEDDING_BACKEND env var (default: "local"). Recognized
    values are "cloud" and "local"; anything else falls back to local —
    with a warning, so typos like "clod" are no longer silently ignored.

    Returns:
        Configured EmbeddingProvider instance.
    """
    backend = os.environ.get("EMBEDDING_BACKEND", "local").lower()
    if backend == "cloud":
        logger.info("Using cloud (Qwen) embedding provider")
        return QwenEmbeddingProvider()
    if backend != "local":
        # Keep the historical fall-back behavior, but make it visible.
        logger.warning(f"Unknown EMBEDDING_BACKEND {backend!r}; falling back to local")
    logger.info("Using local (SentenceTransformer) embedding provider")
    return SentenceTransformerProvider()
