#!/usr/bin/env python3
import glob
import os
import subprocess
import sys
from dataclasses import dataclass
from dataclasses import field
from typing import Any

# Hand-written module map preamble. main() writes this text verbatim at the
# top of the generated module.modulemap, before the per-target modules.
# It declares "_system" (libc/POSIX headers under /usr/include) plus one
# [system] module per third-party library; resolve_module_name() maps the
# corresponding Bazel deps onto these module names.
# NOTE(review): this comment used to say that system modules are found via
# -fimplicit-module-maps, but main() passes -fno-implicit-module-maps and
# -fno-builtin-module-map to clang, and "_system" is defined by hand below --
# confirm which behaviour is intended.
EXTRA_MODULES = """
module "_system" [system] {
  header "/usr/include/stdint.h"
  header "/usr/include/stdbool.h"
  header "/usr/include/stddef.h"
  header "/usr/include/stdio.h"
  header "/usr/include/stdlib.h"
  header "/usr/include/string.h"
  header "/usr/include/inttypes.h"
  header "/usr/include/limits.h"
  header "/usr/include/errno.h"
  header "/usr/include/unistd.h"
  header "/usr/include/fcntl.h"
  header "/usr/include/time.h"
  header "/usr/include/assert.h"
  header "/usr/include/pthread.h"
  header "/usr/include/arpa/inet.h"
  header "/usr/include/netdb.h"
  header "/usr/include/netinet/in.h"
  header "/usr/include/sys/types.h"
  header "/usr/include/sys/stat.h"
  header "/usr/include/sys/time.h"
  header "/usr/include/sys/socket.h"
  header "/usr/include/sys/epoll.h"
  header "/usr/include/sys/ioctl.h"
  header "/usr/include/sys/resource.h"
  header "/usr/include/sys/uio.h"
  header "/usr/include/sys/mman.h"
  header "/usr/include/sys/wait.h"
  header "/usr/include/sys/select.h"
  header "/usr/include/poll.h"
  header "/usr/include/sched.h"
  header "/usr/include/signal.h"
  header "/usr/include/ctype.h"
  header "/usr/include/alloca.h"
  header "/usr/include/malloc.h"
  header "/usr/include/dirent.h"
  export *
}
module "_libsodium" [system] {
  header "/usr/include/sodium.h"
  use _system
  export *
}
module "_benchmark" [system] {
  header "/usr/include/benchmark/benchmark.h"
  use _system
  export *
}
module "_com_google_googletest___gtest" [system] {
  header "/usr/include/gtest/gtest.h"
  header "/usr/include/gmock/gmock.h"
  use _system
  export *
}
module "_com_google_googletest___gtest_main" [system] {
    use _com_google_googletest___gtest
    export *
}
module "_opus" [system] {
  header "/usr/include/opus/opus.h"
  use _system
  export *
}
module "_libvpx" [system] {
  header "/usr/include/vpx/vpx_encoder.h"
  header "/usr/include/vpx/vpx_decoder.h"
  use _system
  export *
}
"""


@dataclass
class Target:
    """One C/C++ build target collected from a BUILD.bazel file.

    Each instance becomes one module in the generated module map; the
    module is named after :attr:`label`.
    """

    # Target name as given to the Bazel rule (the part after ':').
    name: str
    # Directory of the BUILD.bazel file that declared the target.
    package: str
    # Source files, relative to ``package``.
    srcs: list[str] = field(default_factory=list)
    # Header files, relative to ``package``.
    hdrs: list[str] = field(default_factory=list)
    # Dependency labels exactly as written in the BUILD file.
    deps: list[str] = field(default_factory=list)

    @property
    def label(self) -> str:
        """Return the module name for this target: ``c-toxcore/<package>:<name>``."""
        return "c-toxcore/" + self.package + ":" + self.name


# Global registry of every target seen while evaluating BUILD files.
# BuildContext._add_target appends to it; main() reads it to emit the module
# map and to map source files to their owning module.
TARGETS: list[Target] = []


class BuildContext:

    def __init__(self, package: str):
        self.package = package

    def bzl_load(self, *args: Any, **kwargs: Any) -> None:
        pass

    def bzl_exports_files(self, *args: Any, **kwargs: Any) -> None:
        pass

    def bzl_alias(self, *args: Any, **kwargs: Any) -> None:
        pass

    def bzl_sh_library(self, *args: Any, **kwargs: Any) -> None:
        pass

    def bzl_cc_fuzz_test(self, *args: Any, **kwargs: Any) -> None:
        pass

    def bzl_select(self, selector: dict[str, list[str]]) -> list[str]:
        return selector.get("//tools/config:linux",
                            selector.get("//conditions:default", []))

    def bzl_glob(self,
                 include: list[str],
                 exclude: list[str] | None = None,
                 **kwargs: Any) -> list[str]:
        results = []
        for pattern in include:
            full_pattern = os.path.join(self.package, pattern)
            files = glob.glob(full_pattern, recursive=True)
            results.extend([os.path.relpath(f, self.package) for f in files])

        if exclude:
            excluded_files = set()
            for pattern in exclude:
                full_pattern = os.path.join(self.package, pattern)
                files = glob.glob(full_pattern, recursive=True)
                excluded_files.update(
                    [os.path.relpath(f, self.package) for f in files])
            results = [f for f in results if f not in excluded_files]

        return results

    def _add_target(self, name: str, srcs: Any, hdrs: Any, deps: Any) -> None:
        srcs = list(srcs) if srcs else []
        hdrs = list(hdrs) if hdrs else []
        deps = list(deps) if deps else []
        TARGETS.append(Target(name, self.package, srcs, hdrs, deps))

    def bzl_cc_library(self,
                       name: str,
                       srcs: Any = (),
                       hdrs: Any = (),
                       deps: Any = (),
                       **kwargs: Any) -> None:
        self._add_target(name, srcs, hdrs, deps)

    def bzl_cc_binary(self,
                      name: str,
                      srcs: Any = (),
                      hdrs: Any = (),
                      deps: Any = (),
                      **kwargs: Any) -> None:
        self._add_target(name, srcs, hdrs, deps)

    def bzl_project(self, *args: Any, **kwargs: Any) -> None:
        pass

    def bzl_cc_test(self,
                    name: str,
                    srcs: Any = (),
                    hdrs: Any = (),
                    deps: Any = (),
                    **kwargs: Any) -> None:
        self._add_target(name, srcs, hdrs, deps)


def resolve_module_name(dep: str, current_pkg: str) -> str:
    """Translate a Bazel dependency label into a module map module name.

    Args:
        dep: Dependency label as written in a BUILD file.
        current_pkg: Package of the target that lists the dependency; used
            only to expand package-relative labels (``:name``).

    Returns:
        The hand-written module name for known external repositories, the
        label without its leading ``//`` for absolute labels (with the
        implicit ``:<basename>`` target appended when no target is given),
        ``c-toxcore/<current_pkg>:<name>`` for relative labels, and *dep*
        unchanged for everything else.
    """
    external_modules = {
        "@psocket": "_system",
        "@pthread": "_system",
        "@libsodium": "_libsodium",
        "@benchmark": "_benchmark",
        "@com_google_googletest//:gtest": "_com_google_googletest___gtest",
        "@com_google_googletest//:gtest_main":
            "_com_google_googletest___gtest_main",
        "@opus": "_opus",
        "@libvpx": "_libvpx",
    }
    known = external_modules.get(dep)
    if known is not None:
        return known

    if dep.startswith("//"):
        target = dep[2:]
        if ":" not in target:
            # "//a/b" is shorthand for "//a/b:b".
            target = f"{target}:{os.path.basename(target)}"
        return target

    if dep.startswith(":"):
        return f"c-toxcore/{current_pkg}{dep}"

    # Unrecognised external repositories ("@...") and any other spelling are
    # passed through untouched.
    return dep


def _find_packages(root: str = ".") -> list[str]:
    """Return the directories under *root* that contain a BUILD.bazel file.

    Paths are relative to *root*; the root directory itself is reported as
    the empty string.
    """
    packages = []
    for dirpath, dirnames, filenames in os.walk(root):
        # os.walk yields directory entries in arbitrary order. Sorting the
        # list in place fixes the traversal order, so targets are always
        # collected (and modules emitted) in the same sequence.
        dirnames.sort()
        if "BUILD.bazel" in filenames:
            pkg = os.path.relpath(dirpath, root)
            packages.append("" if pkg == "." else pkg)
    return packages


def _render_modulemap(targets: "list[Target]") -> str:
    """Return the module map text: EXTRA_MODULES plus one module per target."""
    parts = [EXTRA_MODULES]
    for t in targets:
        parts.append(f'module "{t.label}" {{\n')
        # Proper modular headers.
        parts.extend(f'  header "{os.path.join(t.package, hdr)}"\n'
                     for hdr in t.hdrs)
        # Every declared dependency becomes a `use` declaration.
        parts.extend(f'  use "{resolve_module_name(dep, t.package)}"\n'
                     for dep in t.deps)
        # Re-export everything we use to match standard C++ transitive
        # include behavior.
        parts.append("  export *\n")
        # Basic system modules used everywhere.
        parts.append('  use "_system"\n')
        parts.append("}\n")
    return "".join(parts)


def _clang_command(src: str, module_name: str) -> list[str]:
    """Return the clang argv that syntax-checks *src* as part of *module_name*."""
    is_c = src.endswith(".c")
    lang = "-xc" if is_c else "-xc++"
    std = "-std=c11" if is_c else "-std=c++23"
    return [
        "clang",
        "-fsyntax-only",
        lang,
        "-stdlib=libc++" if lang == "-xc++" else "-nostdinc++",
        "-Wall",
        "-Werror",
        "-Wno-missing-braces",
        "-DTCP_SERVER_USE_EPOLL",
        "-D_XOPEN_SOURCE=600",
        "-D_GNU_SOURCE",
        std,
        "-fdiagnostics-color=always",
        "-fmodules",
        "-Xclang",
        "-fmodules-local-submodule-visibility",
        "-fmodules-decluse",
        "-Xclang",
        "-fno-modules-error-recovery",
        "-fno-implicit-module-maps",
        "-fno-builtin-module-map",
        "-fmodules-cache-path=/tmp/clang-modules",
        "-fmodule-map-file=module.modulemap",
        f"-fmodule-name={module_name}",
        "-I.",
        "-I/usr/include/opus",
        src,
    ]


def main() -> None:
    """Generate module.modulemap from BUILD.bazel files and validate sources.

    Steps:
      1. Evaluate every BUILD.bazel below the current directory with the
         BuildContext stubs, which fills the global TARGETS list.
      2. Write module.modulemap (EXTRA_MODULES plus one module per target).
      3. With ``--print-modulemap`` on the command line, print the module
         map and stop.
      4. Otherwise run clang with ``-fsyntax-only`` on every ``.c`` source
         that belongs to a target.

    Raises:
        subprocess.CalledProcessError: if clang rejects a source file.
    """
    # Imported once here rather than on every loop iteration.
    from collections import defaultdict

    for pkg in _find_packages("."):
        build_file = os.path.join(pkg, "BUILD.bazel")
        if not os.path.exists(build_file):
            continue
        ctx = BuildContext(pkg)

        # Use a defaultdict to handle any unknown functions as no-ops.
        env: dict[str, Any] = defaultdict(lambda: lambda *args, **kwargs: None)
        env.update({
            "load": ctx.bzl_load,
            "exports_files": ctx.bzl_exports_files,
            "cc_library": ctx.bzl_cc_library,
            "no_undefined_cc_library": ctx.bzl_cc_library,
            "cc_binary": ctx.bzl_cc_binary,
            "cc_test": ctx.bzl_cc_test,
            "cc_fuzz_test": ctx.bzl_cc_fuzz_test,
            "select": ctx.bzl_select,
            "glob": ctx.bzl_glob,
            "alias": ctx.bzl_alias,
            "sh_library": ctx.bzl_sh_library,
            "project": ctx.bzl_project,
        })

        with open(build_file, "r", encoding="utf-8") as f:
            source = f.read()
        # NOTE: the BUILD file is executed as Python code, so this script
        # must only be run on a trusted checkout. Compiling with the real
        # file name makes tracebacks point at the offending BUILD file
        # instead of "<string>".
        exec(compile(source, build_file, "exec"), env)

    modulemap = _render_modulemap(TARGETS)
    with open("module.modulemap", "w", encoding="utf-8") as f:
        f.write(modulemap)

    if "--print-modulemap" in sys.argv:
        print(modulemap)
        return

    # If two targets list the same file, the one registered last owns it.
    src_to_module = {}
    for t in TARGETS:
        for src in t.srcs:
            src_to_module[os.path.join(t.package, src)] = t.label

    # Only C sources are validated at the moment; _clang_command also knows
    # the C++ flags should this filter be widened.
    all_srcs = sorted(src for src in src_to_module if src.endswith(".c"))
    os.makedirs("/tmp/clang-modules", exist_ok=True)

    for src in all_srcs:
        print(f"Validating {src}", file=sys.stderr)
        subprocess.run(_clang_command(src, src_to_module[src]), check=True)


if __name__ == "__main__":
    main()
