# ruff: noqa: S110
from __future__ import annotations

import asyncio
import threading
import typing

T = typing.TypeVar("T")

# Used by `wait_for_not_none` to narrow `ValueWatcher[X | None]` to `X`.
S = typing.TypeVar("S")


class ValueWatcher(typing.Generic[T]):
    """
    Thread-safe observable value with async watchers.

    Watchers can await value changes via methods like `wait_for` and
    `wait_for_change`. Alternatively, they can add callbacks via `on_change` and
    `on_value`.

    Any thread can set `.value`, and the watcher will react accordingly.

    Ordering: async watchers receive `(old, new)` tuples in the order the
    changes were committed, even when several threads set the value
    concurrently. `on_change` callbacks run synchronously in the setting
    thread after the internal lock is released, so callbacks triggered by
    concurrent setters may interleave.
    """

    def __init__(
        self,
        initial_value: T,
        *,
        on_change: typing.Callable[[T, T], None] | None = None,
    ) -> None:
        """
        Args:
            initial_value: The initial value.
            on_change: Called when the value changes. Good for debug logging.
        """

        # Non-reentrant: never call back into this object while holding it.
        self._lock = threading.Lock()
        self._on_changes: list[typing.Callable[[T, T], None]] = []
        if on_change:
            self._on_changes.append(on_change)

        # Every watcher gets its own (loop, queue) pair. Storing the loop lets
        # the setter use `call_soon_threadsafe` for cross-thread notification.
        # Queue items are (old, new) tuples.
        self._watch_queues: list[
            tuple[asyncio.AbstractEventLoop, asyncio.Queue[tuple[T, T]]]
        ] = []

        # Hold references to fire-and-forget tasks to prevent GC.
        self._background_tasks: set[asyncio.Task[T]] = set()

        self._value = initial_value

    @property
    def value(self) -> T:
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: T) -> None:
        with self._lock:
            # Setting an equal value is a no-op: nothing is notified.
            if new_value == self._value:
                return
            old_value, callbacks = self._commit_locked(new_value)

        # Callbacks run outside the lock so they may read or set the value
        # without deadlocking on the non-reentrant lock.
        self._run_callbacks(callbacks, old_value, new_value)

    def set_if(
        self,
        new_value: T,
        condition: typing.Callable[[T], bool],
    ) -> bool:
        """
        Atomically set the value only if the current value satisfies the
        condition. Returns True if the value was set.

        If the condition holds but `new_value` equals the current value, this
        returns True without notifying anyone.

        `condition` is evaluated while the internal (non-reentrant) lock is
        held, so it must not read `.value` or otherwise re-enter this watcher;
        doing so deadlocks.
        """

        with self._lock:
            if not condition(self._value):
                return False

            if new_value == self._value:
                return True

            old_value, callbacks = self._commit_locked(new_value)

        self._run_callbacks(callbacks, old_value, new_value)
        return True

    def on_change(self, callback: typing.Callable[[T, T], None]) -> None:
        """
        Add a callback that's called when the value changes.

        Args:
            callback: Called with (old_value, new_value) on each change, in the
                thread that performed the change. Exceptions it raises are
                suppressed.
        """

        with self._lock:
            self._on_changes.append(callback)

    def on_value(self, value: T, callback: typing.Callable[[], None]) -> None:
        """
        One-shot callback for when the value equals `value`. Requires a
        running event loop (internally spawns a background task).

        Args:
            value: The value to wait for.
            callback: Called when the internal value equals `value`. It runs as
                a done-callback of the background task, and is skipped if that
                task is cancelled or fails.
        """

        task = asyncio.create_task(self.wait_for(value))
        self._background_tasks.add(task)

        def _done(t: asyncio.Task[T]) -> None:
            self._background_tasks.discard(t)
            if not t.cancelled() and t.exception() is None:
                callback()

        task.add_done_callback(_done)

    async def wait_for(
        self,
        value: T,
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to equal the given value.

        Args:
            value: Return when the internal value is equal to this.
            immediate: If True and the internal value is already equal to the given value, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`. None means wait forever.
        """

        return await self._wait_for_condition(
            lambda v: v == value,
            immediate=immediate,
            timeout=timeout,
        )

    async def wait_for_not(
        self,
        value: T,
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to not equal the given value.

        Args:
            value: Return when the internal value is not equal to this.
            immediate: If True and the internal value is already not equal to the given value, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`. None means wait forever.
        """

        return await self._wait_for_condition(
            lambda v: v != value,
            immediate=immediate,
            timeout=timeout,
        )

    async def wait_for_not_none(
        self: ValueWatcher[S | None],
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> S:
        """
        Wait for the internal value to be not None.

        Args:
            immediate: If True and the internal value is already not None, return immediately. Defaults to True.
            timeout: Seconds to wait before raising `asyncio.TimeoutError`. None means wait forever.
        """

        result = await self._wait_for_condition(
            lambda v: v is not None,
            immediate=immediate,
            timeout=timeout,
        )
        if result is None:
            raise AssertionError("unreachable")
        return result

    async def _wait_for_condition(
        self,
        condition: typing.Callable[[T], bool],
        *,
        immediate: bool = True,
        timeout: float | None = None,
    ) -> T:
        """
        Wait until `condition(current_value)` is true, then return the
        matching value. Handles the TOCTOU gap between checking the current
        value and subscribing to the change queue.
        """

        # Fast path: no task needed if the value already matches.
        if immediate:
            # Read once to avoid a TOCTOU race between check and return.
            current = self.value
            if condition(current):
                return current

        async def _wait() -> T:
            with self._watch() as queue:
                # Re-check after queue registration to close the gap
                # between the fast path above and the queue being live.
                if immediate:
                    # Read once to avoid a TOCTOU race between check and return.
                    current = self.value
                    if condition(current):
                        return current

                while True:
                    _, new = await queue.get()
                    if condition(new):
                        return new

        return await asyncio.wait_for(_wait(), timeout=timeout)

    async def wait_for_change(
        self,
        *,
        timeout: float | None = None,
    ) -> T:
        """
        Wait for the internal value to change.

        Args:
            timeout: Seconds to wait before raising `asyncio.TimeoutError`. None means wait forever.
        """

        async def _wait() -> T:
            with self._watch() as queue:
                _, new = await queue.get()
                return new

        return await asyncio.wait_for(_wait(), timeout=timeout)

    def _watch(self) -> _WatchContextManager[T]:
        """
        Watch for all changes to the value. This method returns a context
        manager so it must be used in a `with` statement.

        Its return value is a queue that yields tuples of the old and new
        values.

        Must be called with a running event loop (`get_running_loop` raises
        RuntimeError otherwise).
        """

        loop = asyncio.get_running_loop()
        queue = asyncio.Queue[tuple[T, T]]()
        with self._lock:
            self._watch_queues.append((loop, queue))

        return _WatchContextManager(
            on_exit=lambda: self._remove_queue(queue),
            queue=queue,
        )

    def _remove_queue(self, queue: asyncio.Queue[tuple[T, T]]) -> None:
        """
        Remove a queue from the watch list in a thread-safe manner.
        """

        with self._lock:
            self._watch_queues = [
                entry for entry in self._watch_queues if entry[1] is not queue
            ]

    def _commit_locked(
        self,
        new_value: T,
    ) -> tuple[T, list[typing.Callable[[T, T], None]]]:
        """
        Store `new_value` and post `(old, new)` to every async watcher.

        The caller MUST hold `self._lock`. Posting happens under the lock so
        that concurrent setters enqueue their changes in commit order.
        `call_soon_threadsafe` only schedules `put_nowait` on the target loop
        and never runs user code, so doing this under the lock cannot
        deadlock.

        Returns:
            The previous value and a snapshot of the `on_change` callbacks.
            The caller must invoke the callbacks after releasing the lock.
        """

        old_value = self._value
        self._value = new_value

        change = (old_value, new_value)
        live: list[
            tuple[asyncio.AbstractEventLoop, asyncio.Queue[tuple[T, T]]]
        ] = []
        for loop, queue in self._watch_queues:
            try:
                # `call_soon_threadsafe` wakes the target loop's selector
                # immediately. A plain `put_nowait` wouldn't poke the self-pipe,
                # so a cross-thread watcher could stall until something else
                # wakes its loop.
                loop.call_soon_threadsafe(queue.put_nowait, change)
            except RuntimeError:
                # A closed loop can never drain this queue. Stop tracking the
                # queue instead of retrying (and leaking it) on every change.
                if loop.is_closed():
                    continue
            live.append((loop, queue))
        self._watch_queues = live

        return old_value, list(self._on_changes)

    @staticmethod
    def _run_callbacks(
        callbacks: list[typing.Callable[[T, T], None]],
        old_value: T,
        new_value: T,
    ) -> None:
        """
        Invoke each `on_change` callback with `(old_value, new_value)`.

        Callers invoke this only after releasing the watcher's lock.
        """

        for on_change in callbacks:
            try:
                on_change(old_value, new_value)
            except Exception:
                # Suppress exceptions from callbacks so one failure doesn't skip
                # the rest.
                pass


class _WatchContextManager(typing.Generic[T]):
    """
    Scope guard for a single watch queue.

    `__enter__` hands back the queue, which yields `(old, new)` value tuples.
    `__exit__` runs the `on_exit` hook whether the `with` body finished
    normally or raised. `ValueWatcher._watch` supplies a hook that removes the
    queue from the watch list. Exceptions from the body are never suppressed.
    """

    def __init__(
        self,
        on_exit: typing.Callable[[], None],
        queue: asyncio.Queue[tuple[T, T]],
    ) -> None:
        self._queue = queue
        self._on_exit = on_exit

    def __enter__(self) -> asyncio.Queue[tuple[T, T]]:
        # A plain queue is returned on purpose rather than an async generator.
        # An async generator here can trigger "Task was destroyed but it is
        # pending!" warnings when the event loop closes.
        watched = self._queue
        return watched

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: object,
    ) -> None:
        cleanup = self._on_exit
        cleanup()
        # Falsy return: let any in-flight exception propagate.
        return None
