from typing import Optional, Sequence
from types import TracebackType
from uuid import UUID

from overrides import override
import httpx
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api.collection_configuration import (
    CreateCollectionConfiguration,
    UpdateCollectionConfiguration,
    validate_embedding_function_conflict_on_create,
    validate_embedding_function_conflict_on_get,
)
from chromadb.api.shared_system_client import SharedSystemClient
from chromadb.api.types import (
    CollectionMetadata,
    DataLoader,
    Documents,
    Embeddable,
    EmbeddingFunction,
    Embeddings,
    GetResult,
    IDs,
    Include,
    Loadable,
    Metadatas,
    QueryResult,
    Schema,
    URIs,
    IncludeMetadataDocuments,
    IncludeMetadataDocumentsDistances,
    DefaultEmbeddingFunction,
)
from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.errors import ChromaAuthError, ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument


class Client(SharedSystemClient, ClientAPI):
    """A client for Chroma. This is the main entrypoint for interacting with Chroma.
    A client internally stores its tenant and database and proxies calls to a
    Server API instance of Chroma. It treats the Server API and corresponding System
    as a singleton, so multiple clients connecting to the same resource will share the
    same API instance.

    Client implementations should implement their own API-caching strategies.
    """

    tenant: str = DEFAULT_TENANT
    database: str = DEFAULT_DATABASE

    _server: ServerAPI
    # An internal admin client for verifying that databases and tenants exist
    _admin_client: AdminAPI
    # Set once close() has run (or __init__ has failed and cleaned up) so that
    # system references are released at most once.
    _closed: bool = False

    # region Initialization
    def __init__(
        self,
        tenant: Optional[str] = DEFAULT_TENANT,
        database: Optional[str] = DEFAULT_DATABASE,
        settings: Settings = Settings(),
    ) -> None:
        """Create a client bound to a tenant and database.

        Args:
            tenant: Tenant to use. If None, the tenant must be resolvable from
                the current user identity.
            database: Database to use. If None, the database must be resolvable
                from the current user identity.
            settings: Settings used to obtain the shared System.

        Raises:
            ChromaAuthError: If no tenant or database was provided and none
                could be determined from the user identity.
            ValueError: If the user identity or the tenant/database could not
                be verified (see `get_user_identity` and
                `_validate_tenant_database`).
        """
        super().__init__(settings=settings)
        try:
            if tenant is not None:
                self.tenant = tenant
            if database is not None:
                self.database = database

            # Get the root system component we want to interact with
            self._server = self._system.instance(ServerAPI)

            user_identity = self.get_user_identity()

            maybe_tenant, maybe_database = maybe_set_tenant_and_database(
                user_identity,
                overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
                user_provided_tenant=tenant,
                user_provided_database=database,
            )

            # this should not happen unless types are invalidated
            if maybe_tenant is None and tenant is None:
                raise ChromaAuthError(
                    "Could not determine a tenant from the current authentication method. Please provide a tenant."
                )
            if maybe_database is None and database is None:
                raise ChromaAuthError(
                    "Could not determine a database name from the current authentication method. Please provide a database name."
                )

            # Values resolved from the identity take precedence over the ones
            # assigned from the arguments above.
            if maybe_tenant:
                self.tenant = maybe_tenant
            if maybe_database:
                self.database = maybe_database

            # Create an admin client for verifying that databases and tenants exist
            self._admin_client = AdminClient.from_system(self._system)
            self._validate_tenant_database(tenant=self.tenant, database=self.database)

            self._submit_client_start_event()
        except Exception:
            # If init fails after refcount was incremented, release references
            # to avoid a resource leak (the caller never receives the object to
            # call close() on it).
            if hasattr(self, "_admin_client"):
                SharedSystemClient._release_system(self._admin_client._identifier)
            SharedSystemClient._release_system(self._identifier)
            # The references are gone; make any later close() on this
            # half-built instance a no-op instead of a second release.
            self._closed = True
            raise

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "Client":
        """Build a client on top of an existing System.

        Args:
            system: The System to register and reuse.
            tenant: Tenant to bind the client to.
            database: Database to bind the client to.

        Returns:
            Client: A client constructed with the system's settings.
        """
        SharedSystemClient._populate_data_from_system(system)
        instance = cls(tenant=tenant, database=database, settings=system.settings)
        return instance

    # endregion

    @override
    def get_user_identity(self) -> UserIdentity:
        """Return the user identity reported by the server.

        Raises:
            ValueError: If the server cannot be reached, or on any other
                non-Chroma failure (the original error is chained as the cause).
            ChromaError: Propagated unchanged from the server.
        """
        try:
            return self._server.get_user_identity()
        except httpx.ConnectError as e:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e
        # Propagate ChromaErrors
        except ChromaError:
            raise
        except Exception as e:
            raise ValueError(str(e)) from e

    # region BaseAPI Methods
    # Note - we could do this in less verbose ways, but they break type checking
    @override
    def heartbeat(self) -> int:
        """Return the server time in nanoseconds since epoch."""
        return self._server.heartbeat()

    @override
    def list_collections(
        self, limit: Optional[int] = None, offset: Optional[int] = None
    ) -> Sequence[Collection]:
        """List collections for the current tenant and database, with pagination.

        Args:
            limit: Optional page size, passed through to the server.
            offset: Optional page offset, passed through to the server.

        Returns:
            Sequence[Collection]: Collection objects for the current tenant.
        """
        return [
            Collection(client=self._server, model=model)
            for model in self._server.list_collections(
                limit, offset, tenant=self.tenant, database=self.database
            )
        ]

    @override
    def count_collections(self) -> int:
        """Return the number of collections in the current database."""
        return self._server.count_collections(
            tenant=self.tenant, database=self.database
        )

    @staticmethod
    def _resolve_create_configuration(
        configuration: Optional[CreateCollectionConfiguration],
        embedding_function: Optional[EmbeddingFunction[Embeddable]],
    ) -> CreateCollectionConfiguration:
        """Return the configuration to send when creating a collection.

        Validates that `embedding_function` does not conflict with the one in
        `configuration`, then fills in the configuration's embedding function
        from the argument when the configuration has none.

        The result is always a new dict: the caller's `configuration` is never
        mutated, so one configuration object can be reused across calls.
        """
        if configuration is None:
            resolved: CreateCollectionConfiguration = {}
        else:
            # Shallow copy: only the top-level "embedding_function" key is
            # ever written below.
            resolved = configuration.copy()

        configuration_ef = resolved.get("embedding_function")

        validate_embedding_function_conflict_on_create(
            embedding_function, configuration_ef
        )

        # If ef provided in function params and collection config ef is None,
        # set the collection config ef to the function params
        if embedding_function is not None and configuration_ef is None:
            resolved["embedding_function"] = embedding_function

        return resolved

    @override
    def create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> Collection:
        """Create a collection with optional configuration and metadata.

        If using a schema, do not provide `embedding_function`. Instead,
        provide the `embedding_function` as part of the schema.

        Args:
            name: Collection name.
            schema: Optional collection schema for indexes and encryption.
            configuration: Optional collection configuration. The passed dict
                is not modified.
            metadata: Optional collection metadata.
            embedding_function: Optional embedding function for the collection.
            data_loader: Optional data loader for documents with URIs.
            get_or_create: Whether to return an existing collection if present.

        Returns:
            Collection: The created collection.

        Raises:
            ValueError: If the embedding function conflicts with configuration.
        """
        resolved_configuration = self._resolve_create_configuration(
            configuration, embedding_function
        )

        model = self._server.create_collection(
            name=name,
            schema=schema,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            get_or_create=get_or_create,
            configuration=resolved_configuration,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    def get_collection(
        self,
        name: str,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Get a collection by name.

        Args:
            name: Collection name.
            embedding_function: Optional embedding function for the collection.
            data_loader: Optional data loader for documents with URIs.

        Returns:
            Collection: The requested collection.

        Raises:
            ValueError: If the embedding function conflicts with configuration.
        """
        model = self._server.get_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )
        persisted_ef_config = model.configuration_json.get("embedding_function")

        validate_embedding_function_conflict_on_get(
            embedding_function, persisted_ef_config
        )

        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    def get_or_create_collection(
        self,
        name: str,
        schema: Optional[Schema] = None,
        configuration: Optional[CreateCollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Get an existing collection or create a new one.

        If the collection does not exist, it will be created. If the collection
        already exists, the schema, configuration, and metadata arguments
        will be ignored.

        Args:
            name: Collection name.
            schema: Optional collection schema for indexes and encryption.
            configuration: Optional collection configuration. The passed dict
                is not modified.
            metadata: Optional collection metadata.
            embedding_function: Optional embedding function for the collection.
            data_loader: Optional data loader for URI-backed data.

        Returns:
            Collection: The existing or newly created collection.

        Raises:
            ValueError: If the embedding function does not match the collection's embedding function.
        """
        resolved_configuration = self._resolve_create_configuration(
            configuration, embedding_function
        )

        model = self._server.get_or_create_collection(
            name=name,
            schema=schema,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            configuration=resolved_configuration,
        )

        # The collection may already have existed, so also check the argument
        # against the embedding function configuration the server returned.
        persisted_ef_config = model.configuration_json.get("embedding_function")

        validate_embedding_function_conflict_on_get(
            embedding_function, persisted_ef_config
        )

        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
        new_configuration: Optional[UpdateCollectionConfiguration] = None,
    ) -> None:
        """Forward a collection modification to the server for this tenant/database."""
        return self._server._modify(
            id=id,
            tenant=self.tenant,
            database=self.database,
            new_name=new_name,
            new_metadata=new_metadata,
            new_configuration=new_configuration,
        )

    @override
    def delete_collection(
        self,
        name: str,
    ) -> None:
        """Delete the named collection in the current tenant and database.

        Args:
            name: Collection name.
        """
        return self._server.delete_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )

    #
    # ITEM METHODS
    #

    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward an add request to the server for this tenant/database."""
        return self._server._add(
            ids=ids,
            tenant=self.tenant,
            database=self.database,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward an update request to the server for this tenant/database."""
        return self._server._update(
            collection_id=collection_id,
            tenant=self.tenant,
            database=self.database,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward an upsert request to the server for this tenant/database."""
        return self._server._upsert(
            collection_id=collection_id,
            tenant=self.tenant,
            database=self.database,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _count(self, collection_id: UUID) -> int:
        """Forward a count request to the server for this tenant/database."""
        return self._server._count(
            collection_id=collection_id,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Forward a peek request for `n` items to the server for this tenant/database."""
        return self._server._peek(
            collection_id=collection_id,
            n=n,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocuments,
    ) -> GetResult:
        """Forward a get request to the server for this tenant/database."""
        return self._server._get(
            collection_id=collection_id,
            tenant=self.tenant,
            database=self.database,
            ids=ids,
            where=where,
            limit=limit,
            offset=offset,
            where_document=where_document,
            include=include,
        )

    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        """Forward a delete request to the server for this tenant/database."""
        self._server._delete(
            collection_id=collection_id,
            tenant=self.tenant,
            database=self.database,
            ids=ids,
            where=where,
            where_document=where_document,
        )

    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        ids: Optional[IDs] = None,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = IncludeMetadataDocumentsDistances,
    ) -> QueryResult:
        """Forward a query request to the server for this tenant/database."""
        return self._server._query(
            collection_id=collection_id,
            ids=ids,
            tenant=self.tenant,
            database=self.database,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include,
        )

    @override
    def reset(self) -> bool:
        """Delegate to the server's `reset()` and return its result."""
        return self._server.reset()

    @override
    def get_version(self) -> str:
        """Return the version string reported by the server."""
        return self._server.get_version()

    @override
    def get_settings(self) -> Settings:
        """Return the settings reported by the server."""
        return self._server.get_settings()

    @override
    def get_max_batch_size(self) -> int:
        """Return the maximum batch size reported by the server."""
        return self._server.get_max_batch_size()

    # endregion

    # region ClientAPI Methods

    @override
    def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        """Switch this client to another tenant and database.

        The pair is validated first; if validation raises, the client's current
        tenant and database are left unchanged.

        Args:
            tenant: Tenant to switch to.
            database: Database to switch to within that tenant.
        """
        self._validate_tenant_database(tenant=tenant, database=database)
        self.tenant = tenant
        self.database = database

    @override
    def set_database(self, database: str) -> None:
        """Switch this client to another database within the current tenant.

        The database is validated first; if validation raises, the client's
        current database is left unchanged.

        Args:
            database: Database to switch to.
        """
        self._validate_tenant_database(tenant=self.tenant, database=database)
        self.database = database

    def close(self) -> None:
        """Close the client and release all resources.

        This method decrements the reference count for the underlying System.
        When the last client using a shared System calls close(), the System
        is stopped and all resources (database connections, etc.) are released.

        This is particularly important for PersistentClient to avoid SQLite
        file locking issues.

        Note: If multiple clients share the same System (e.g., multiple PersistentClient
        instances with the same path), the System will only be stopped when the last
        client is closed. This allows safe use of context managers with multiple clients.

        Example:
            >>> client = chromadb.PersistentClient(path="./chroma_db")
            >>> # ... use client ...
            >>> client.close()

            Or using context manager:
            >>> with chromadb.PersistentClient(path="./chroma_db") as client:
            ...     # ... use client ...
        """
        # Make close() idempotent - a second call is a safe no-op
        if self._closed:
            return
        self._closed = True

        # Release the internal admin client's reference first, since it also
        # incremented the refcount for the shared system on creation.
        if hasattr(self, "_admin_client"):
            SharedSystemClient._release_system(self._admin_client._identifier)

        # Release our own reference; stops system if this was the last client
        SharedSystemClient._release_system(self._identifier)

    def __enter__(self) -> "Client":
        """Context manager entry."""
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Context manager exit: close the client; exceptions are not suppressed."""
        self.close()

    def _validate_tenant_database(self, tenant: str, database: str) -> None:
        """Check that `tenant` and `database` can be fetched via the admin client.

        Args:
            tenant: Tenant name to look up.
            database: Database name to look up within `tenant`.

        Raises:
            ValueError: If the server cannot be reached, or the tenant lookup
                fails with a non-Chroma error (the original error is chained
                as the cause).
            ChromaError: Propagated unchanged from the tenant lookup.
        """
        try:
            self._admin_client.get_tenant(name=tenant)
        except httpx.ConnectError as e:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e
        # Propagate ChromaErrors
        except ChromaError:
            raise
        except Exception as e:
            raise ValueError(
                f"Could not connect to tenant {tenant}. Are you sure it exists?"
            ) from e

        # Only connection failures are translated here; any other error from
        # the database lookup propagates to the caller unchanged.
        try:
            self._admin_client.get_database(name=database, tenant=tenant)
        except httpx.ConnectError as e:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e

    # endregion


class AdminClient(SharedSystemClient, AdminAPI):
    """Client exposing tenant and database administration.

    Every operation is forwarded to the ServerAPI instance obtained from the
    underlying System.
    """

    _server: ServerAPI

    def __init__(self, settings: Settings = Settings()) -> None:
        """Initialise the shared system and look up its ServerAPI instance.

        Args:
            settings: Settings passed to the shared-system base initialiser.
        """
        super().__init__(settings)
        server_api = self._system.instance(ServerAPI)
        self._server = server_api

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
    ) -> "AdminClient":
        """Build an admin client on top of an existing System.

        Args:
            system: The System to register and reuse.

        Returns:
            AdminClient: A client constructed with the system's settings.
        """
        SharedSystemClient._populate_data_from_system(system)
        return cls(settings=system.settings)

    # region Tenant operations

    @override
    def create_tenant(self, name: str) -> None:
        """Create a tenant.

        Args:
            name: Tenant name.
        """
        return self._server.create_tenant(name=name)

    @override
    def get_tenant(self, name: str) -> Tenant:
        """Get a tenant by name.

        Args:
            name: Tenant name.

        Returns:
            Tenant: The tenant record returned by the server.
        """
        tenant_record = self._server.get_tenant(name=name)
        return tenant_record

    # endregion

    # region Database operations

    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Create a database inside a tenant.

        Args:
            name: Name of the database to create.
            tenant: Tenant the database belongs to.
        """
        return self._server.create_database(name=name, tenant=tenant)

    @override
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Look up a database by name.

        Args:
            name: Name of the database.
            tenant: Tenant the database belongs to.

        Returns:
            Database: The database record returned by the server.
        """
        database_record = self._server.get_database(name=name, tenant=tenant)
        return database_record

    @override
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Delete a database by name.

        Args:
            name: Name of the database to delete.
            tenant: Tenant the database belongs to.
        """
        return self._server.delete_database(name=name, tenant=tenant)

    @override
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        """List the databases of a tenant, with pagination.

        Args:
            limit: Optional page size, passed through to the server.
            offset: Optional page offset, passed through to the server.
            tenant: Tenant whose databases are listed.

        Returns:
            Sequence[Database]: The database records returned by the server.
        """
        database_records = self._server.list_databases(limit, offset, tenant=tenant)
        return database_records

    # endregion
