"""Test that C functions used in primitives are declared in a header such as CPy.h."""

from __future__ import annotations

import glob
import os
import re
import unittest

from mypyc.ir.deps import SourceDep
from mypyc.ir.ops import PrimitiveDescription
from mypyc.primitives import (
    bytearray_ops,
    bytes_ops,
    dict_ops,
    exc_ops,
    float_ops,
    generic_ops,
    int_ops,
    librt_strings_ops,
    librt_vecs_ops,
    list_ops,
    misc_ops,
    registry,
    set_ops,
    str_ops,
    tuple_ops,
    weakref_ops,
)


class TestHeaderInclusion(unittest.TestCase):
    """Check that C functions referenced by primitives are declared in lib-rt headers."""

    def test_primitives_included_in_header(self) -> None:
        """Assert that each ``CPy``-prefixed C function used by a primitive is declared.

        The searched text is the concatenation of ``CPy.h``, ``pythonsupport.h``
        and every existing header reported by a ``SourceDep`` dependency of a
        collected op. Names are gathered from two places:

        * the ``c_function_name`` attribute of ops found in the registry tables
          and of ``PrimitiveDescription`` attributes of the primitive modules;
        * ``c_function_name="CPy..."`` literals in the ``primitives/*.py`` sources.

        Names that do not start with ``CPy`` are not checked.
        """
        base_dir = os.path.join(os.path.dirname(__file__), "..", "lib-rt")
        with open(os.path.join(base_dir, "CPy.h")) as f:
            header = f.read()
        with open(os.path.join(base_dir, "pythonsupport.h")) as f:
            header += f.read()

        def check_name(name: str) -> None:
            # Reads ``header`` at call time, so headers appended below are included.
            if name.startswith("CPy"):
                # Escape the name so it is always matched literally as a whole word.
                assert re.search(
                    rf"\b{re.escape(name)}\b", header
                ), f'"{name}" is used in mypyc.primitives but not declared in CPy.h'

        # Ops registered in the registry tables (each table value is a list of ops).
        all_ops = []
        for values in [
            registry.method_call_ops.values(),
            registry.binary_ops.values(),
            registry.unary_ops.values(),
            registry.function_ops.values(),
        ]:
            for ops in values:
                all_ops.extend(ops)

        # PrimitiveDescription objects exposed as attributes of the primitive modules.
        for module in [
            bytes_ops,
            str_ops,
            dict_ops,
            list_ops,
            bytearray_ops,
            generic_ops,
            int_ops,
            misc_ops,
            tuple_ops,
            exc_ops,
            float_ops,
            set_ops,
            weakref_ops,
            librt_vecs_ops,
            librt_strings_ops,
        ]:
            for name in dir(module):
                val = getattr(module, name, None)
                if isinstance(val, PrimitiveDescription):
                    all_ops.append(val)

        # Find additional headers via extra C source file dependencies.
        # Several ops may share a dependency, so read each header only once.
        seen_headers: set[str] = set()
        for op in all_ops:
            for dep in op.dependencies or ():
                if not isinstance(dep, SourceDep):
                    continue
                header_fnam = os.path.join(base_dir, dep.get_header())
                if header_fnam in seen_headers:
                    continue
                seen_headers.add(header_fnam)
                if os.path.isfile(header_fnam):
                    # header_fnam already includes base_dir; do not join it again.
                    with open(header_fnam) as f:
                        header += f.read()

        for op in all_ops:
            if op.c_function_name is not None:
                check_name(op.c_function_name)

        # Also check names that appear as string literals in the primitive sources.
        primitives_path = os.path.join(os.path.dirname(__file__), "..", "primitives")
        for fnam in glob.glob(os.path.join(primitives_path, "*.py")):
            with open(fnam) as f:
                content = f.read()
            for name in re.findall(r'c_function_name=["\'](CPy[A-Z_a-z0-9]+)', content):
                check_name(name)
