import requests
import json
import os
import warnings
from typing import Literal, Sequence, Optional, List, Union, Generator
from .utils import get_max_items_from_list
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError

class TavilyClient:
    """
    Tavily API client class.

    Wraps a ``requests.Session`` pre-configured with the authorization,
    client-source and (optional) project headers, and exposes one method per
    API endpoint: ``search``, ``extract``, ``crawl``, ``map``, ``research`` and
    ``get_research``. Use it as a context manager (or call ``close``) to release
    the underlying session.
    """

    def __init__(self, api_key: Optional[str] = None, proxies: Optional[dict[str, str]] = None, api_base_url: Optional[str] = None, client_source: Optional[str] = None, project_id: Optional[str] = None):
        """
        Create a client.

        Args:
            api_key: API key; falls back to the ``TAVILY_API_KEY`` environment variable.
            proxies: Optional ``{"http": ..., "https": ...}`` mapping. When it is
                empty or not given, ``TAVILY_HTTP_PROXY`` / ``TAVILY_HTTPS_PROXY``
                are read instead. Empty proxy values are discarded.
            api_base_url: Base URL of the API (default ``https://api.tavily.com``).
            client_source: Value of the ``X-Client-Source`` header (default ``tavily-python``).
            project_id: Sent as ``X-Project-ID``; falls back to ``TAVILY_PROJECT``.

        Raises:
            MissingAPIKeyError: if no API key is given or found in the environment.
        """
        if api_key is None:
            api_key = os.getenv("TAVILY_API_KEY")

        if not api_key:
            raise MissingAPIKeyError()

        resolved_proxies = {
            "http": proxies.get("http") if proxies else os.getenv("TAVILY_HTTP_PROXY"),
            "https": proxies.get("https") if proxies else os.getenv("TAVILY_HTTPS_PROXY"),
        }

        # Drop unset entries; an all-empty mapping collapses to None.
        resolved_proxies = {k: v for k, v in resolved_proxies.items() if v} or None
        tavily_project = project_id or os.getenv("TAVILY_PROJECT")

        self.base_url = api_base_url or "https://api.tavily.com"
        self.api_key = api_key
        self.proxies = resolved_proxies

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "X-Client-Source": client_source or "tavily-python",
            **({"X-Project-ID": tavily_project} if tavily_project else {})
        }

        self.session = requests.Session()
        self.session.headers.update(self.headers)
        if self.proxies:
            self.session.proxies.update(self.proxies)

    def close(self):
        """Close the session and release resources."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    # ------------------------------------------------------------------
    # Shared request / error helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _error_detail(response) -> Optional[str]:
        """
        Best-effort extraction of the error message from an error response body.

        Returns ``body["detail"]["error"]`` when ``detail`` is an object (``None``
        if that key is absent), the ``detail`` string itself when it is a plain
        string, and ``""`` when the body is not JSON or has any other shape.
        """
        try:
            body = response.json()
        except Exception:
            # Deliberately broad: reading the body of a streamed response can fail
            # for transport reasons too, and that must not mask the HTTP status.
            return ""

        if not isinstance(body, dict):
            return ""
        detail = body.get("detail", {})
        if isinstance(detail, dict):
            return detail.get("error")
        if isinstance(detail, str):
            return detail
        return ""

    def _raise_api_error(self, response) -> None:
        """
        Raise the exception matching a non-success ``response``; never returns.

        Raises:
            UsageLimitExceededError: status 429.
            ForbiddenError: status 403, 432 or 433.
            InvalidAPIKeyError: status 401.
            BadRequestError: status 400.
            requests.exceptions.HTTPError: any other status, including
                unexpected non-error statuses such as 201 or 204.
        """
        detail = self._error_detail(response)
        status = response.status_code

        if status == 429:
            raise UsageLimitExceededError(detail)
        if status in (403, 432, 433):
            raise ForbiddenError(detail)
        if status == 401:
            raise InvalidAPIKeyError(detail)
        if status == 400:
            raise BadRequestError(detail)

        # raise_for_status() only raises for 4xx/5xx and returns None otherwise,
        # so an explicit fallback is needed (``raise None`` would be a TypeError).
        response.raise_for_status()
        raise requests.exceptions.HTTPError(
            f"Unexpected HTTP status {status} from Tavily API", response=response
        )

    def _post(self, path: str, data: dict, timeout: Optional[float], stream: bool = False):
        """POST ``data`` as JSON to ``base_url + path``; request timeouts become TimeoutError."""
        payload = json.dumps(data)
        try:
            return self.session.post(self.base_url + path, data=payload, timeout=timeout, stream=stream)
        except requests.exceptions.Timeout as exc:
            raise TimeoutError(timeout) from exc

    def _json_or_raise(self, response, ok_statuses: Sequence[int] = (200,)) -> dict:
        """Return the decoded JSON body when the status is in ``ok_statuses``; otherwise raise."""
        if response.status_code in ok_statuses:
            return response.json()
        self._raise_api_error(response)

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def _search(self,
                query: str,
                search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None,
                topic: Literal["general", "news", "finance"] = None,
                time_range: Literal["day", "week", "month", "year"] = None,
                start_date: str = None,
                end_date: str = None,
                days: int = None,
                max_results: int = None,
                include_domains: Sequence[str] = None,
                exclude_domains: Sequence[str] = None,
                include_answer: Union[bool, Literal["basic", "advanced"]] = None,
                include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
                include_images: bool = None,
                timeout: float = 60,
                country: str = None,
                auto_parameters: bool = None,
                include_favicon: bool = None,
                include_usage: bool = None,
                exact_match: bool = None,
                **kwargs
                ) -> dict:
        """
        Internal search method: POST to ``/search`` and return the decoded JSON.

        Arguments left as ``None`` are omitted from the payload. Extra ``kwargs``
        are added after that filtering, so they are sent even when ``None`` and
        override same-named fields. The HTTP timeout is capped at 120 seconds.

        Raises:
            TimeoutError: the HTTP request timed out.
            UsageLimitExceededError, ForbiddenError, InvalidAPIKeyError,
            BadRequestError, requests.exceptions.HTTPError: non-200 responses.
        """
        data = {
            "query": query,
            "search_depth": search_depth,
            "topic": topic,
            "time_range": time_range,
            "start_date": start_date,
            "end_date": end_date,
            "days": days,
            "include_answer": include_answer,
            "include_raw_content": include_raw_content,
            "max_results": max_results,
            "include_domains": include_domains,
            "exclude_domains": exclude_domains,
            "include_images": include_images,
            "country": country,
            "auto_parameters": auto_parameters,
            "include_favicon": include_favicon,
            "include_usage": include_usage,
            "exact_match": exact_match,
        }

        data = {k: v for k, v in data.items() if v is not None}
        data.update(kwargs)

        timeout = min(timeout, 120)
        response = self._post("/search", data, timeout)
        return self._json_or_raise(response)

    def search(self,
               query: str,
               search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = None,
               topic: Literal["general", "news", "finance" ] = None,
               time_range: Literal["day", "week", "month", "year"] = None,
               start_date: str = None,
               end_date: str = None,
               days: int = None,
               max_results: int = None,
               include_domains: Sequence[str] = None,
               exclude_domains: Sequence[str] = None,
               include_answer: Union[bool, Literal["basic", "advanced"]] = None,
               include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
               include_images: bool = None,
               timeout: float = 60,
               country: str = None,
               auto_parameters: bool = None,
               include_favicon: bool = None,
               include_usage: bool = None,
               exact_match: bool = None,
               **kwargs,  # Accept custom arguments
               ) -> dict:
        """
        Combined search method.

        Sends the search request (see ``_search`` for payload and error rules) and
        returns the response dict with a ``"results"`` key guaranteed present
        (an empty list if the API omitted it).
        """
        response_dict = self._search(query,
                                     search_depth=search_depth,
                                     topic=topic,
                                     time_range=time_range,
                                     start_date=start_date,
                                     end_date=end_date,
                                     days=days,
                                     max_results=max_results,
                                     include_domains=include_domains,
                                     exclude_domains=exclude_domains,
                                     include_answer=include_answer,
                                     include_raw_content=include_raw_content,
                                     include_images=include_images,
                                     timeout=timeout,
                                     country=country,
                                     auto_parameters=auto_parameters,
                                     include_favicon=include_favicon,
                                     include_usage=include_usage,
                                     exact_match=exact_match,
                                     **kwargs)
        response_dict.setdefault("results", [])
        return response_dict

    # ------------------------------------------------------------------
    # Extract
    # ------------------------------------------------------------------

    def _extract(self,
                 urls: Union[List[str], str],
                 include_images: bool = None,
                 extract_depth: Literal["basic", "advanced"] = None,
                 format: Literal["markdown", "text"] = None,
                 timeout: float = 30,
                 include_favicon: bool = None,
                 include_usage: bool = None,
                 query: str = None,
                 chunks_per_source: int = None,
                 **kwargs
                 ) -> dict:
        """
        Internal extract method: POST to ``/extract`` and return the decoded JSON.

        ``timeout`` is sent in the payload and also used as the HTTP timeout.
        ``None`` arguments are omitted; ``kwargs`` are added after that filtering.
        """
        data = {
            "urls": urls,
            "include_images": include_images,
            "extract_depth": extract_depth,
            "format": format,
            "timeout": timeout,
            "include_favicon": include_favicon,
            "include_usage": include_usage,
            "query": query,
            "chunks_per_source": chunks_per_source,
        }

        data = {k: v for k, v in data.items() if v is not None}
        data.update(kwargs)

        response = self._post("/extract", data, timeout)
        return self._json_or_raise(response)

    def extract(self,
                urls: Union[List[str], str],  # Accept a list of URLs or a single URL
                include_images: bool = None,
                extract_depth: Literal["basic", "advanced"] = None,
                format: Literal["markdown", "text"] = None,
                timeout: float = 30,
                include_favicon: bool = None,
                include_usage: bool = None,
                query: str = None,
                chunks_per_source: int = None,
                **kwargs,  # Accept custom arguments
                ) -> dict:
        """
        Combined extract method.

        Returns the response dict with ``"results"`` and ``"failed_results"``
        guaranteed present (empty lists if the API omitted them).
        """
        response_dict = self._extract(urls,
                                      include_images=include_images,
                                      extract_depth=extract_depth,
                                      format=format,
                                      timeout=timeout,
                                      include_favicon=include_favicon,
                                      include_usage=include_usage,
                                      query=query,
                                      chunks_per_source=chunks_per_source,
                                      **kwargs)
        response_dict.setdefault("results", [])
        response_dict.setdefault("failed_results", [])
        return response_dict

    # ------------------------------------------------------------------
    # Crawl / Map
    # ------------------------------------------------------------------

    def _crawl(self,
            url: str,
            max_depth: int = None,
            max_breadth: int = None,
            limit: int = None,
            instructions: str = None,
            select_paths: Sequence[str] = None,
            select_domains: Sequence[str] = None,
            exclude_paths: Sequence[str] = None,
            exclude_domains: Sequence[str] = None,
            allow_external: bool = None,
            include_images: bool = None,
            extract_depth: Literal["basic", "advanced"] = None,
            format: Literal["markdown", "text"] = None,
            timeout: float = 150,
            include_favicon: bool = None,
            include_usage: bool = None,
            chunks_per_source: int = None,
            **kwargs
            ) -> dict:
        """
        Internal crawl method: POST to ``/crawl`` and return the decoded JSON.

        include_favicon: If True, include the favicon in the crawl results.

        Unlike search/extract, ``kwargs`` are merged *before* ``None`` filtering,
        so a custom argument passed as ``None`` is dropped from the payload.
        ``timeout`` is sent in the payload and also used as the HTTP timeout.
        """
        data = {
            "url": url,
            "max_depth": max_depth,
            "max_breadth": max_breadth,
            "limit": limit,
            "instructions": instructions,
            "select_paths": select_paths,
            "select_domains": select_domains,
            "exclude_paths": exclude_paths,
            "exclude_domains": exclude_domains,
            "allow_external": allow_external,
            "include_images": include_images,
            "extract_depth": extract_depth,
            "format": format,
            "timeout": timeout,
            "include_favicon": include_favicon,
            "include_usage": include_usage,
            "chunks_per_source": chunks_per_source,
        }

        data.update(kwargs)
        data = {k: v for k, v in data.items() if v is not None}

        response = self._post("/crawl", data, timeout)
        return self._json_or_raise(response)

    def crawl(self,
              url: str,
              max_depth: int = None,
              max_breadth: int = None,
              limit: int = None,
              instructions: str = None,
              select_paths: Sequence[str] = None,
              select_domains: Sequence[str] = None,
              exclude_paths: Sequence[str] = None,
              exclude_domains: Sequence[str] = None,
              allow_external: bool = None,
              include_images: bool = None,
              extract_depth: Literal["basic", "advanced"] = None,
              format: Literal["markdown", "text"] = None,
              timeout: float = 150,
              include_favicon: bool = None,
              include_usage: bool = None,
              chunks_per_source: int = None,
              **kwargs
              ) -> dict:
        """
        Combined crawl method.
        include_favicon: If True, include the favicon in the crawl results.

        Returns the decoded response dict unchanged.
        """
        return self._crawl(url,
                           max_depth=max_depth,
                           max_breadth=max_breadth,
                           limit=limit,
                           instructions=instructions,
                           select_paths=select_paths,
                           select_domains=select_domains,
                           exclude_paths=exclude_paths,
                           exclude_domains=exclude_domains,
                           allow_external=allow_external,
                           include_images=include_images,
                           extract_depth=extract_depth,
                           format=format,
                           timeout=timeout,
                           include_favicon=include_favicon,
                           include_usage=include_usage,
                           chunks_per_source=chunks_per_source,
                           **kwargs)

    def _map(self,
            url: str,
            max_depth: int = None,
            max_breadth: int = None,
            limit: int = None,
            instructions: str = None,
            select_paths: Sequence[str] = None,
            select_domains: Sequence[str] = None,
            exclude_paths: Sequence[str] = None,
            exclude_domains: Sequence[str] = None,
            allow_external: bool = None,
            include_images: bool = None,
            timeout: float = 150,
            include_usage: bool = None,
            **kwargs
            ) -> dict:
        """
        Internal map method: POST to ``/map`` and return the decoded JSON.

        As in ``_crawl``, ``kwargs`` are merged before ``None`` filtering, and
        ``timeout`` is sent in the payload and also used as the HTTP timeout.
        """
        data = {
            "url": url,
            "max_depth": max_depth,
            "max_breadth": max_breadth,
            "limit": limit,
            "instructions": instructions,
            "select_paths": select_paths,
            "select_domains": select_domains,
            "exclude_paths": exclude_paths,
            "exclude_domains": exclude_domains,
            "allow_external": allow_external,
            "include_images": include_images,
            "timeout": timeout,
            "include_usage": include_usage,
        }

        data.update(kwargs)
        data = {k: v for k, v in data.items() if v is not None}

        response = self._post("/map", data, timeout)
        return self._json_or_raise(response)

    def map(self,
              url: str,
              max_depth: int = None,
              max_breadth: int = None,
              limit: int = None,
              instructions: str = None,
              select_paths: Sequence[str] = None,
              select_domains: Sequence[str] = None,
              exclude_paths: Sequence[str] = None,
              exclude_domains: Sequence[str] = None,
              allow_external: bool = None,
              include_images: bool = None,
              timeout: float = 150,
              include_usage: bool = None,
              **kwargs
              ) -> dict:
        """
        Combined map method.

        Returns the decoded response dict unchanged.
        """
        return self._map(url,
                         max_depth=max_depth,
                         max_breadth=max_breadth,
                         limit=limit,
                         instructions=instructions,
                         select_paths=select_paths,
                         select_domains=select_domains,
                         exclude_paths=exclude_paths,
                         exclude_domains=exclude_domains,
                         allow_external=allow_external,
                         include_images=include_images,
                         timeout=timeout,
                         include_usage=include_usage,
                         **kwargs)

    # ------------------------------------------------------------------
    # Deprecated search helpers
    # ------------------------------------------------------------------

    def get_search_context(self,
                           query: str,
                           search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = "basic",
                           topic: Literal["general", "news", "finance"] = "general",
                           days: int = 7,
                           max_results: int = 5,
                           include_domains: Sequence[str] = None,
                           exclude_domains: Sequence[str] = None,
                           max_tokens: int = 4000,
                           timeout: float = 60,
                           country: str = None,
                           include_favicon: bool = None,
                           **kwargs,  # Accept custom arguments
                           ) -> str:
        """
        Get the search context for a query. Useful for getting only related content from retrieved websites
        without having to deal with context extraction and limitation yourself.

        Deprecated: emits a DeprecationWarning.

        max_tokens: The maximum number of tokens to return (based on openai token compute). Defaults to 4000.

        Returns a JSON string: a list of ``{"url", "content"}`` objects, trimmed by
        ``get_max_items_from_list`` to fit ``max_tokens``.
        """
        warnings.warn("get_search_context is deprecated and will be removed in future versions.",
                      DeprecationWarning, stacklevel=2)

        response_dict = self._search(query,
                                     search_depth=search_depth,
                                     topic=topic,
                                     days=days,
                                     max_results=max_results,
                                     include_domains=include_domains,
                                     exclude_domains=exclude_domains,
                                     include_answer=False,
                                     include_raw_content=False,
                                     include_images=False,
                                     timeout=timeout,
                                     country=country,
                                     include_favicon=include_favicon,
                                     **kwargs,
                                     )
        sources = response_dict.get("results", [])
        context = [{"url": source["url"], "content": source["content"]}
                   for source in sources]
        return json.dumps(get_max_items_from_list(context, max_tokens))

    def qna_search(self,
                   query: str,
                   search_depth: Literal["basic", "advanced", "fast", "ultra-fast"] = "advanced",
                   topic: Literal["general", "news", "finance"] = "general",
                   days: int = 7,
                   max_results: int = 5,
                   include_domains: Sequence[str] = None,
                   exclude_domains: Sequence[str] = None,
                   timeout: float = 60,
                   country: str = None,
                   include_favicon: bool = None,
                   **kwargs,  # Accept custom arguments
                   ) -> str:
        """
        Q&A search method. Search depth is advanced by default to get the best answer.

        Deprecated: emits a DeprecationWarning. Returns the response's ``"answer"``
        field, or ``""`` if absent.
        """
        warnings.warn("qna_search is deprecated and will be removed in future versions.",
                      DeprecationWarning, stacklevel=2)
        response_dict = self._search(query,
                                     search_depth=search_depth,
                                     topic=topic,
                                     days=days,
                                     max_results=max_results,
                                     include_domains=include_domains,
                                     exclude_domains=exclude_domains,
                                     include_raw_content=False,
                                     include_images=False,
                                     include_answer=True,
                                     timeout=timeout,
                                     country=country,
                                     include_favicon=include_favicon,
                                     **kwargs,
                                     )
        return response_dict.get("answer", "")

    # ------------------------------------------------------------------
    # Research
    # ------------------------------------------------------------------

    def _research(self,
                  input: str,
                  model: Literal["mini", "pro", "auto"] = None,
                  output_schema: dict = None,
                  stream: bool = False,
                  citation_format: Literal["numbered", "mla", "apa", "chicago"] = "numbered",
                  timeout: Optional[float] = None,
                  **kwargs
                  ) -> Union[dict, Generator[bytes, None, None]]:
        """
        Internal research method: POST to ``/research``.

        Returns the decoded JSON when ``stream`` is false. Otherwise returns a
        generator of the non-empty raw byte chunks of the response body; the
        response is closed when the generator finishes or is closed.
        """
        data = {
            "input": input,
            "model": model,
            "output_schema": output_schema,
            "stream": stream,
            "citation_format": citation_format,
        }

        data = {k: v for k, v in data.items() if v is not None}
        data.update(kwargs)

        if not stream:
            response = self._post("/research", data, timeout)
            return self._json_or_raise(response)

        response = self._post("/research", data, timeout, stream=True)
        if response.status_code != 200:
            # With stream=True the connection is held until the body is consumed
            # or the response is closed; release it once the error is built.
            try:
                self._raise_api_error(response)
            finally:
                response.close()

        def stream_generator() -> Generator[bytes, None, None]:
            # NOTE(review): if the caller never starts iterating, this finally
            # never runs and the response stays open until garbage-collected.
            try:
                for chunk in response.iter_content(chunk_size=None):
                    if chunk:
                        yield chunk
            finally:
                response.close()

        return stream_generator()

    def research(self,
                 input: str,
                 model: Literal["mini", "pro", "auto"] = None,
                 output_schema: dict = None,
                 stream: bool = False,
                 citation_format: Literal["numbered", "mla", "apa", "chicago"] = "numbered",
                 timeout: Optional[float] = None,
                 **kwargs
                 ) -> Union[dict, Generator[bytes, None, None]]:
        """
        Research method to create a research task.

        Args:
            input: The research task or question to investigate (required).
            model: The model used by the research agent - must be either 'mini', 'pro', or 'auto'.
            output_schema: Schema for the 'structured_output' response format (JSON Schema dict).
            stream: Whether to stream the research task.
            citation_format: Citation format - must be either 'numbered', 'mla', 'apa', or 'chicago'.
            timeout: Optional HTTP request timeout in seconds.
            **kwargs: Additional custom arguments.

        Returns:
            dict: Response containing request_id, created_at, status, input, and model
            (when ``stream`` is false), or a generator of raw byte chunks when streaming.
        """
        return self._research(
            input=input,
            model=model,
            output_schema=output_schema,
            stream=stream,
            citation_format=citation_format,
            timeout=timeout,
            **kwargs
        )

    def get_research(self,
                     request_id: str,
                     timeout: Optional[float] = None,
                     ) -> dict:
        """
        Get research results by request_id.

        Args:
            request_id: The research request ID.
            timeout: Optional HTTP request timeout in seconds (default: no timeout,
                as before this parameter existed).

        Returns:
            dict: Research response containing request_id, created_at, completed_at, status, content, and sources.
            Both HTTP 200 and 202 bodies are returned.

        Raises:
            TimeoutError: the request exceeded ``timeout``.
            Exception: any other failure while sending the request (original error chained).
            UsageLimitExceededError, ForbiddenError, InvalidAPIKeyError,
            BadRequestError, requests.exceptions.HTTPError: other statuses.
        """
        try:
            response = self.session.get(self.base_url + f"/research/{request_id}", timeout=timeout)
        except requests.exceptions.Timeout as exc:
            raise TimeoutError(timeout) from exc
        except Exception as exc:
            raise Exception(f"Error getting research: {exc}") from exc

        return self._json_or_raise(response, ok_statuses=(200, 202))


class Client(TavilyClient):
    """
    Tavily API client class.

    WARNING! This class is deprecated. Please use TavilyClient instead.

    All positional and keyword arguments are forwarded to ``TavilyClient.__init__``.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn("Client is deprecated, please use TavilyClient instead",
                      DeprecationWarning, stacklevel=2)
        # The previous signature was ``__init__(self, kwargs)``: one positional
        # parameter (misleadingly named) that was forwarded as ``api_key``, which
        # made ``Client(api_key=...)`` a TypeError. Keep the old ``kwargs=``
        # keyword spelling working for existing callers.
        if "kwargs" in kwargs and not args:
            args = (kwargs.pop("kwargs"),)
        super().__init__(*args, **kwargs)
