# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause


from cpython cimport Py_INCREF, PyObject, PyTypeObject

from libc.math cimport fabsf
from libc.stdlib cimport free
from libc.string cimport memcpy
from libc.stdio cimport printf
from libc.stdint cimport SIZE_MAX

from sklearn.tree._utils cimport safe_realloc

import numpy as np
cimport numpy as cnp
cnp.import_array()

cdef extern from "numpy/arrayobject.h":
    object PyArray_NewFromDescr(PyTypeObject* subtype, cnp.dtype descr,
                                int nd, cnp.npy_intp* dims,
                                cnp.npy_intp* strides,
                                void* data, int flags, object obj)
    int PyArray_SetBaseObject(cnp.ndarray arr, PyObject* obj)

# Build the corresponding numpy dtype for Cell.
# This works by casting `dummy` to an array of Cell of length 1, which numpy
# can construct a `dtype`-object for. See https://stackoverflow.com/q/62448946
# for a more detailed explanation.
cdef Cell dummy
CELL_DTYPE = np.asarray(<Cell[:1]>(&dummy)).dtype

assert CELL_DTYPE.itemsize == sizeof(Cell)

cdef const float EPSILON = 1e-6


cdef class _QuadTree:
    """Array-based representation of a QuadTree.

    This class is currently working for indexing 2D data (regular QuadTree) and
    for indexing 3D data (OcTree). It is planned to split the 2 implementations
    using `Cython.Tempita` to save some memory for QuadTree.

    Note that this code is currently internally used only by the Barnes-Hut
    method in `sklearn.manifold.TSNE`. It is planned to be refactored and
    generalized in the future to be compatible with nearest neighbors API of
    `sklearn.neighbors` with 2D and 3D data.
    """
    def __cinit__(self, int n_dimensions, int verbose):
        """Constructor."""
        # Parameters of the tree
        self.n_dimensions = n_dimensions
        self.verbose = verbose
        self.n_cells_per_cell = <int> (2 ** self.n_dimensions)

        # Inner structures
        self.max_depth = 0
        self.cell_count = 0
        self.capacity = 0
        self.n_points = 0
        self.cells = NULL

    def __dealloc__(self):
        """Destructor."""
        # Free all inner structures
        free(self.cells)

    @property
    def cumulative_size(self):
        cdef Cell[:] cell_mem_view = self._get_cell_ndarray()
        return cell_mem_view.base['cumulative_size'][:self.cell_count]

    @property
    def leafs(self):
        cdef Cell[:] cell_mem_view = self._get_cell_ndarray()
        return cell_mem_view.base['is_leaf'][:self.cell_count]

    def build_tree(self, X):
        """Build a tree from an array of points X."""
        cdef:
            int i
            float32_t[3] pt
            float32_t[3] min_bounds, max_bounds

        # validate X and prepare for query
        # X = check_array(X, dtype=float32_t, order='C')
        n_samples = X.shape[0]

        capacity = 100
        self._resize(capacity)
        m = np.min(X, axis=0)
        M = np.max(X, axis=0)
        # Scale the maximum to get all points strictly in the tree bounding box
        # The 3 bounds are for positive, negative and small values
        M = np.maximum(M * (1. + 1e-3 * np.sign(M)), M + 1e-3)
        for i in range(self.n_dimensions):
            min_bounds[i] = m[i]
            max_bounds[i] = M[i]

            if self.verbose > 10:
                printf("[QuadTree] bounding box axis %i : [%f, %f]\n",
                       i, min_bounds[i], max_bounds[i])

        # Create the initial node with boundaries from the dataset
        self._init_root(min_bounds, max_bounds)

        for i in range(n_samples):
            for j in range(self.n_dimensions):
                pt[j] = X[i, j]
            self.insert_point(pt, i)

        # Shrink the cells array to reduce memory usage
        self._resize(capacity=self.cell_count)

    cdef int insert_point(self, float32_t[3] point, intp_t point_index,
                          intp_t cell_id=0) except -1 nogil:
        """Insert a point in the QuadTree."""
        cdef int ax
        cdef intp_t selected_child
        cdef Cell* cell = &self.cells[cell_id]
        cdef intp_t n_point = cell.cumulative_size

        if self.verbose > 10:
            printf("[QuadTree] Inserting depth %li\n", cell.depth)

        # Assert that the point is in the right range
        if DEBUGFLAG:
            self._check_point_in_cell(point, cell)

        # If the cell is an empty leaf, insert the point in it
        if cell.cumulative_size == 0:
            cell.cumulative_size = 1
            self.n_points += 1
            for i in range(self.n_dimensions):
                cell.barycenter[i] = point[i]
            cell.point_index = point_index
            if self.verbose > 10:
                printf("[QuadTree] inserted point %li in cell %li\n",
                       point_index, cell_id)
            return cell_id

        # If the cell is not a leaf, update cell internals and
        # recurse in selected child
        if not cell.is_leaf:
            for ax in range(self.n_dimensions):
                # barycenter update using a weighted mean
                cell.barycenter[ax] = (
                    n_point * cell.barycenter[ax] + point[ax]) / (n_point + 1)

            # Increase the size of the subtree starting from this cell
            cell.cumulative_size += 1

            # Insert child in the correct subtree
            selected_child = self._select_child(point, cell)
            if self.verbose > 49:
                printf("[QuadTree] selected child %li\n", selected_child)
            if selected_child == -1:
                self.n_points += 1
                return self._insert_point_in_new_child(point, cell, point_index)
            return self.insert_point(point, point_index, selected_child)

        # Finally, if the cell is a leaf with a point already inserted,
        # split the cell in n_cells_per_cell if the point is not a duplicate.
        # If it is a duplicate, increase the size of the leaf and return.
        if self._is_duplicate(point, cell.barycenter):
            if self.verbose > 10:
                printf("[QuadTree] found a duplicate!\n")
            cell.cumulative_size += 1
            self.n_points += 1
            return cell_id

        # In a leaf, the barycenter correspond to the only point included
        # in it.
        self._insert_point_in_new_child(cell.barycenter, cell, cell.point_index,
                                        cell.cumulative_size)
        return self.insert_point(point, point_index, cell_id)

    # XXX: This operation is not Thread safe
    cdef intp_t _insert_point_in_new_child(
        self, float32_t[3] point, Cell* cell, intp_t point_index, intp_t size=1
    ) noexcept nogil:
        """Create a child of cell which will contain point."""

        # Local variable definition
        cdef:
            intp_t cell_id, cell_child_id, parent_id
            float32_t[3] save_point
            float32_t width
            Cell* child
            int i

        # If the maximal capacity of the Tree have been reached, double the capacity
        # We need to save the current cell id and the current point to retrieve them
        # in case the reallocation
        if self.cell_count + 1 > self.capacity:
            parent_id = cell.cell_id
            for i in range(self.n_dimensions):
                save_point[i] = point[i]
            self._resize(SIZE_MAX)
            cell = &self.cells[parent_id]
            point = save_point

        # Get an empty cell and initialize it
        cell_id = self.cell_count
        self.cell_count += 1
        child = &self.cells[cell_id]

        self._init_cell(child, cell.cell_id, cell.depth + 1)
        child.cell_id = cell_id

        # Set the cell as an inner cell of the Tree
        cell.is_leaf = False
        cell.point_index = -1

        # Set the correct boundary for the cell, store the point in the cell
        # and compute its index in the children array.
        cell_child_id = 0
        for i in range(self.n_dimensions):
            cell_child_id *= 2
            if point[i] >= cell.center[i]:
                cell_child_id += 1
                child.min_bounds[i] = cell.center[i]
                child.max_bounds[i] = cell.max_bounds[i]
            else:
                child.min_bounds[i] = cell.min_bounds[i]
                child.max_bounds[i] = cell.center[i]
            child.center[i] = (child.min_bounds[i] + child.max_bounds[i]) / 2.
            width = child.max_bounds[i] - child.min_bounds[i]

            child.barycenter[i] = point[i]
            child.squared_max_width = max(child.squared_max_width, width*width)

        # Store the point info and the size to account for duplicated points
        child.point_index = point_index
        child.cumulative_size = size

        # Store the child cell in the correct place in children
        cell.children[cell_child_id] = child.cell_id

        if DEBUGFLAG:
            # Assert that the point is in the right range
            self._check_point_in_cell(point, child)
        if self.verbose > 10:
            printf("[QuadTree] inserted point %li in new child %li\n",
                   point_index, cell_id)

        return cell_id

    cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil:
        """Check if the two given points are equals."""
        cdef int i
        cdef bint res = True
        for i in range(self.n_dimensions):
            # Use EPSILON to avoid numerical error that would overgrow the tree
            res &= fabsf(point1[i] - point2[i]) <= EPSILON
        return res

    cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil:
        """Select the child of cell which contains the given query point."""
        cdef:
            int i
            intp_t selected_child = 0

        for i in range(self.n_dimensions):
            # Select the correct child cell to insert the point by comparing
            # it to the borders of the cells using precomputed center.
            selected_child *= 2
            if point[i] >= cell.center[i]:
                selected_child += 1
        return cell.children[selected_child]

    cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil:
        """Initialize a cell structure with some constants."""
        cell.parent = parent
        cell.is_leaf = True
        cell.depth = depth
        cell.squared_max_width = 0
        cell.cumulative_size = 0
        for i in range(self.n_cells_per_cell):
            cell.children[i] = SIZE_MAX

    cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds
                         ) noexcept nogil:
        """Initialize the root node with the given space boundaries"""
        cdef:
            int i
            float32_t width
            Cell* root = &self.cells[0]

        self._init_cell(root, -1, 0)
        for i in range(self.n_dimensions):
            root.min_bounds[i] = min_bounds[i]
            root.max_bounds[i] = max_bounds[i]
            root.center[i] = (max_bounds[i] + min_bounds[i]) / 2.
            width = max_bounds[i] - min_bounds[i]
            root.squared_max_width = max(root.squared_max_width, width*width)
        root.cell_id = 0

        self.cell_count += 1

    cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell
                                  ) except -1 nogil:
        """Check that the given point is in the cell boundaries."""

        if self.verbose >= 50:
            if self.n_dimensions == 3:
                printf("[QuadTree] Checking point (%f, %f, %f) in cell %li "
                       "([%f/%f, %f/%f, %f/%f], size %li)\n",
                       point[0], point[1], point[2], cell.cell_id,
                       cell.min_bounds[0], cell.max_bounds[0], cell.min_bounds[1],
                       cell.max_bounds[1], cell.min_bounds[2], cell.max_bounds[2],
                       cell.cumulative_size)
            else:
                printf("[QuadTree] Checking point (%f, %f) in cell %li "
                       "([%f/%f, %f/%f], size %li)\n",
                       point[0], point[1], cell.cell_id, cell.min_bounds[0],
                       cell.max_bounds[0], cell.min_bounds[1],
                       cell.max_bounds[1], cell.cumulative_size)

        for i in range(self.n_dimensions):
            if (cell.min_bounds[i] > point[i] or
                    cell.max_bounds[i] <= point[i]):
                with gil:
                    msg = "[QuadTree] InsertionError: point out of cell "
                    msg += "boundary.\nAxis %li: cell [%f, %f]; point %f\n"

                    msg %= i, cell.min_bounds[i],  cell.max_bounds[i], point[i]
                    raise ValueError(msg)

    def _check_coherence(self):
        """Check the coherence of the cells of the tree.

        Check that the info stored in each cell is compatible with the info
        stored in descendent and sibling cells. Raise a ValueError if this
        fails.
        """
        for cell in self.cells[:self.cell_count]:
            # Check that the barycenter of inserted point is within the cell
            # boundaries
            self._check_point_in_cell(cell.barycenter, &cell)

            if not cell.is_leaf:
                # Compute the number of point in children and compare with
                # its cummulative_size.
                n_points = 0
                for idx in range(self.n_cells_per_cell):
                    child_id = cell.children[idx]
                    if child_id != -1:
                        child = self.cells[child_id]
                        n_points += child.cumulative_size
                        assert child.cell_id == child_id, (
                            "Cell id not correctly initialized.")
                if n_points != cell.cumulative_size:
                    raise ValueError(
                        "Cell {} is incoherent. Size={} but found {} points "
                        "in children. ({})"
                        .format(cell.cell_id, cell.cumulative_size,
                                n_points, cell.children))

        # Make sure that the number of point in the tree correspond to the
        # cumulative size in root cell.
        if self.n_points != self.cells[0].cumulative_size:
            raise ValueError(
                "QuadTree is incoherent. Size={} but found {} points "
                "in children."
                .format(self.n_points, self.cells[0].cumulative_size))

    cdef long summarize(self, float32_t[3] point, float32_t* results,
                        float squared_theta=.5, intp_t cell_id=0, long idx=0
                        ) noexcept nogil:
        """Summarize the tree compared to a query point.

        Input arguments
        ---------------
        point : array (n_dimensions)
             query point to construct the summary.
        cell_id : integer, optional (default: 0)
            current cell of the tree summarized. This should be set to 0 for
            external calls.
        idx : integer, optional (default: 0)
            current index in the result array. This should be set to 0 for
            external calls
        squared_theta: float, optional (default: .5)
            threshold to decide whether the node is sufficiently far
            from the query point to be a good summary. The formula is such that
            the node is a summary if
                node_width^2 / dist_node_point^2 < squared_theta.
            Note that the argument should be passed as theta^2 to avoid
            computing square roots of the distances.

        Output arguments
        ----------------
        results : array (n_samples * (n_dimensions+2))
            result will contain a summary of the tree information compared to
            the query point:
            - results[idx:idx+n_dimensions] contains the coordinate-wise
                difference between the query point and the summary cell idx.
                This is useful in t-SNE to compute the negative forces.
            - result[idx+n_dimensions+1] contains the squared euclidean
                distance to the summary cell idx.
            - result[idx+n_dimensions+2] contains the number of point of the
                tree contained in the summary cell idx.

        Return
        ------
        idx : integer
            number of elements in the results array.
        """
        cdef:
            int i, idx_d = idx + self.n_dimensions
            bint duplicate = True
            Cell* cell = &self.cells[cell_id]

        results[idx_d] = 0.
        for i in range(self.n_dimensions):
            results[idx + i] = point[i] - cell.barycenter[i]
            results[idx_d] += results[idx + i] * results[idx + i]
            duplicate &= fabsf(results[idx + i]) <= EPSILON

        # Do not compute self interactions
        if duplicate and cell.is_leaf:
            return idx

        # Check whether we can use this node as a summary
        # It's a summary node if the angular size as measured from the point
        # is relatively small (w.r.t. theta) or if it is a leaf node.
        # If it can be summarized, we use the cell center of mass
        # Otherwise, we go a higher level of resolution and into the leaves.
        if cell.is_leaf or (
                (cell.squared_max_width / results[idx_d]) < squared_theta):
            results[idx_d + 1] = <float32_t> cell.cumulative_size
            return idx + self.n_dimensions + 2

        else:
            # Recursively compute the summary in nodes
            for c in range(self.n_cells_per_cell):
                child_id = cell.children[c]
                if child_id != -1:
                    idx = self.summarize(point, results, squared_theta,
                                         child_id, idx)

        return idx

    def get_cell(self, point):
        """return the id of the cell containing the query point or raise
        ValueError if the point is not in the tree
        """
        cdef float32_t[3] query_pt
        cdef int i

        assert len(point) == self.n_dimensions, (
            "Query point should be a point in dimension {}."
            .format(self.n_dimensions))

        for i in range(self.n_dimensions):
            query_pt[i] = point[i]

        return self._get_cell(query_pt, 0)

    cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=0
                       ) except -1 nogil:
        """guts of get_cell.

        Return the id of the cell containing the query point or raise ValueError
        if the point is not in the tree"""
        cdef:
            intp_t selected_child
            Cell* cell = &self.cells[cell_id]

        if cell.is_leaf:
            if self._is_duplicate(cell.barycenter, point):
                if self.verbose > 99:
                    printf("[QuadTree] Found point in cell: %li\n",
                           cell.cell_id)
                return cell_id
            with gil:
                raise ValueError("Query point not in the Tree.")

        selected_child = self._select_child(point, cell)
        if selected_child > 0:
            if self.verbose > 99:
                printf("[QuadTree] Selected_child: %li\n", selected_child)
            return self._get_cell(point, selected_child)
        with gil:
            raise ValueError("Query point not in the Tree.")

    # Pickling primitives

    def __reduce__(self):
        """Reduce re-implementation, for pickling."""
        return (_QuadTree, (self.n_dimensions, self.verbose), self.__getstate__())

    def __getstate__(self):
        """Getstate re-implementation, for pickling."""
        d = {}
        # capacity is inferred during the __setstate__ using nodes
        d["max_depth"] = self.max_depth
        d["cell_count"] = self.cell_count
        d["capacity"] = self.capacity
        d["n_points"] = self.n_points
        d["cells"] = self._get_cell_ndarray().base
        return d

    def __setstate__(self, d):
        """Setstate re-implementation, for unpickling."""
        self.max_depth = d["max_depth"]
        self.cell_count = d["cell_count"]
        self.capacity = d["capacity"]
        self.n_points = d["n_points"]

        if 'cells' not in d:
            raise ValueError('You have loaded Tree version which '
                             'cannot be imported')

        cell_ndarray = d['cells']

        if (cell_ndarray.ndim != 1 or
                cell_ndarray.dtype != CELL_DTYPE or
                not cell_ndarray.flags.c_contiguous):
            raise ValueError('Did not recognise loaded array layout')

        self.capacity = cell_ndarray.shape[0]
        if self._resize_c(self.capacity) != 0:
            raise MemoryError("resizing tree to %d" % self.capacity)

        cdef Cell[:] cell_mem_view = cell_ndarray
        memcpy(
            pto=self.cells,
            pfrom=&cell_mem_view[0],
            size=self.capacity * sizeof(Cell),
        )

    # Array manipulation methods, to convert it to numpy or to resize
    # self.cells array

    cdef Cell[:] _get_cell_ndarray(self):
        """Wraps nodes as a NumPy struct array.

        The array keeps a reference to this Tree, which manages the underlying
        memory. Individual fields are publicly accessible as properties of the
        Tree.
        """
        cdef cnp.npy_intp shape[1]
        shape[0] = <cnp.npy_intp> self.cell_count
        cdef cnp.npy_intp strides[1]
        strides[0] = sizeof(Cell)
        cdef Cell[:] arr
        Py_INCREF(CELL_DTYPE)
        arr = PyArray_NewFromDescr(
            subtype=<PyTypeObject *> np.ndarray,
            descr=CELL_DTYPE,
            nd=1,
            dims=shape,
            strides=strides,
            data=<void*> self.cells,
            flags=cnp.NPY_ARRAY_DEFAULT,
            obj=None,
        )
        Py_INCREF(self)
        if PyArray_SetBaseObject(arr.base, <PyObject*> self) < 0:
            raise ValueError("Can't initialize array!")
        return arr

    cdef int _resize(self, intp_t capacity) except -1 nogil:
        """Resize all inner arrays to `capacity`, if `capacity` == -1, then
           double the size of the inner arrays.

        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        if self._resize_c(capacity) != 0:
            # Acquire gil only if we need to raise
            with gil:
                raise MemoryError()

    cdef int _resize_c(self, intp_t capacity=SIZE_MAX) except -1 nogil:
        """Guts of _resize

        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        if capacity == self.capacity and self.cells != NULL:
            return 0

        if <size_t> capacity == SIZE_MAX:
            if self.capacity == 0:
                capacity = 9  # default initial value to min
            else:
                capacity = 2 * self.capacity

        safe_realloc(&self.cells, capacity)

        # if capacity smaller than cell_count, adjust the counter
        if capacity < self.cell_count:
            self.cell_count = capacity

        self.capacity = capacity
        return 0

    def _py_summarize(self, float32_t[:] query_pt, float32_t[:, :] X, float angle):
        # Used for testing summarize
        cdef:
            float32_t[:] summary
            int n_samples

        n_samples = X.shape[0]
        summary = np.empty(4 * n_samples, dtype=np.float32)

        idx = self.summarize(&query_pt[0], &summary[0], angle * angle)
        return idx, summary
