#!/usr/bin/env python3

# Copyright 2025 Google LLC
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import difflib
import os
import shutil
import subprocess
import sys
import tempfile
from contextlib import contextmanager

# Context manager yielding the directory to work in.
#
# When `path` is truthy it is yielded unchanged: it is neither created nor
# removed here. Otherwise a fresh temporary directory is created, yielded, and
# deleted again when the `with` block exits.
@contextmanager
def get_folder(path=None):
    if not path:
        with tempfile.TemporaryDirectory() as scratch_dir:
            yield scratch_dir
        return
    yield path

# NOTE: every path below is relative, so it is resolved against the current
# working directory (presumably the root of the Dawn checkout - confirm).

# Third party dependencies necessary for Dawn to generate its version of the
# header.
THIRD_PARTY_DIR = 'third_party'
JINJA2_DIR = os.path.join(THIRD_PARTY_DIR, 'jinja2')
MARKUPSAFE_DIR = os.path.join(THIRD_PARTY_DIR, 'markupsafe')

# Dawn generator constants: the generator script, its template directory, the
# JSON description it consumes and the generator target that is requested.
DAWN_GENERATOR = os.path.join('generator', 'dawn_json_generator.py')
DAWN_GENERATOR_TEMPLATE_DIR = os.path.join('generator', 'templates')
DAWN_JSON = os.path.join('src', 'dawn', 'dawn.json')
DAWN_WEBGPU_HEADER_TARGET = 'webgpu_headers'

# Upstream webgpu-header constants.
WEBGPU_HEADER = os.path.join(THIRD_PARTY_DIR, 'webgpu-headers', 'src',
                             'webgpu.h')

# Default output file locations. These are joined onto an output directory by
# the helpers that use them.
DAWN_GENERATED_HEADER = os.path.join('webgpu-headers', 'webgpu.h')
WEBGPU_GENERATED_HEADER = os.path.join('webgpu-headers', 'webgpu_upstream.h')
GENERATED_DIFF = os.path.join('webgpu-headers', 'webgpu.h.diff')


# Helper that removes all comments from headers so that they can be diff-ed easier.
def _remove_comments(input_path, output_path):
    lines = []
    with open(input_path, 'r') as input:
        last_line_was_comment = False
        in_comment_block = False
        for line in input:
            stripped = line.strip()

            # If we are in a comment block, skip until we find the end.
            if in_comment_block:
                if stripped.endswith("*/"):
                    in_comment_block = False
                    last_line_was_comment = True
                continue

            # Drop empty lines directly following comments.
            if not stripped and last_line_was_comment:
                continue

            if stripped.startswith('//'):
                # Simple single line comment case, just skip the line.
                last_line_was_comment = True
                continue

            if stripped.startswith("/*") and stripped.endswith("*/"):
                # Simple single line comment with multi-line syntax, just skip the line.
                last_line_was_comment = True
                continue

            if stripped.startswith("/*") and "*/" not in stripped:
                # Otherwise, if we are a multi-line comment that isn't inlined.
                in_comment_block = True
                continue

            last_line_was_comment = False
            lines.append(line)

    with open(output_path, 'w') as output:
        for line in lines:
            output.write(line)


# Helper that runs Dawn's JSON generator for the webgpu header target with
# `dir` as the output directory, then strips the comments from the resulting
# header in place.
#
# Raises subprocess.CalledProcessError if the generator exits with a non-zero
# status.
def _gen_dawn_header(dir):
    generator_cmd = [sys.executable, DAWN_GENERATOR]
    generator_cmd += ['--template-dir', DAWN_GENERATOR_TEMPLATE_DIR]
    generator_cmd += ['--output-dir', dir]
    generator_cmd += ['--dawn-json', DAWN_JSON]
    generator_cmd += ['--targets', DAWN_WEBGPU_HEADER_TARGET]
    generator_cmd += ['--jinja2-path', JINJA2_DIR]
    generator_cmd += ['--markupsafe-path', MARKUPSAFE_DIR]
    subprocess.check_call(generator_cmd)

    # Same path for input and output: the header is rewritten without comments.
    generated_header = os.path.join(dir, DAWN_GENERATED_HEADER)
    _remove_comments(generated_header, generated_header)


# Helper that writes a comment-free copy of the upstream webgpu.h into `dir`
# (at the WEBGPU_GENERATED_HEADER location). The upstream file itself is left
# untouched.
def _gen_webgpu_header(dir):
    stripped_copy = os.path.join(dir, WEBGPU_GENERATED_HEADER)
    _remove_comments(WEBGPU_HEADER, stripped_copy)


# Helper that generates the diff between the Dawn and upstream headers.
#
# Both comment-free headers are produced inside `temp_dir`; the diff (upstream
# as the "from" side, Dawn as the "to" side, no context lines) is written to
# GENERATED_DIFF under `output_dir`.
def _gen_header_diff(temp_dir, output_dir):
    _gen_dawn_header(temp_dir)
    _gen_webgpu_header(temp_dir)

    dawn_path = os.path.join(temp_dir, DAWN_GENERATED_HEADER)
    upstream_path = os.path.join(temp_dir, WEBGPU_GENERATED_HEADER)
    diff_path = os.path.join(output_dir, GENERATED_DIFF)

    with open(dawn_path, 'r') as dawn_src, \
         open(upstream_path, 'r') as upstream_src, \
         open(diff_path, 'w') as out:
        upstream_lines = upstream_src.readlines()
        dawn_lines = dawn_src.readlines()
        hunks = difflib.unified_diff(upstream_lines,
                                     dawn_lines,
                                     fromfile='webgpu_header',
                                     tofile='dawn_header',
                                     n=0)
        # Hunk headers carry line numbers, which change every time anything in
        # the file changes, so each one is collapsed to a bare '@@' marker.
        out.writelines('@@\n' if entry.startswith('@@') else entry
                       for entry in hunks)


# Helper that verifies that all the necessary dependencies to run checks are
# available: the jinja2 and markupsafe directories and the upstream webgpu.h.
#
# Returns True when everything is present. Otherwise prints a message for the
# first missing dependency and returns False; main() treats that as a pass.
def _check_deps():
    # (probe, location, message printed when the probe fails), in check order.
    requirements = (
        (os.path.isdir, JINJA2_DIR,
         f'Third party dependency jinja2 is not available at {JINJA2_DIR}'),
        (os.path.isdir, MARKUPSAFE_DIR,
         f'Third party dependency markupsafe is not available at {MARKUPSAFE_DIR}'
         ),
        (os.path.exists, WEBGPU_HEADER,
         f'Third party upstream webgpu.h is not available at {WEBGPU_HEADER}'),
    )
    for is_present, location, complaint in requirements:
        if not is_present(location):
            print(complaint)
            return False
    return True


# Command line entry point.
#
# `args` is the full argument vector with the program name first (the shape of
# sys.argv). Supported invocations:
#   gen   - regenerate the diff file under THIRD_PARTY_DIR.
#   check - regenerate the diff into a scratch location and compare it against
#           the one under THIRD_PARTY_DIR.
# `--outdir` selects the working directory; without it a temporary directory
# is used and removed afterwards.
#
# Returns the process exit status: 0 on success (including a 'check' skipped
# because dependencies are missing), 1 when 'check' finds that the diff file
# is out of date. Invalid arguments make argparse raise SystemExit.
def main(args):
    parser = argparse.ArgumentParser()
    parser.add_argument('--outdir', default=None)
    parser.add_argument('action', choices=('gen', 'check'))
    # Parse the arguments that were actually passed in instead of implicitly
    # re-reading sys.argv. args[0] is the program name, so it is skipped. An
    # empty or None `args` falls back to argparse's default (sys.argv[1:]).
    options = parser.parse_args(args[1:] if args else None)

    if options.action == 'gen':
        with get_folder(options.outdir) as temp_dir:
            _gen_header_diff(temp_dir, THIRD_PARTY_DIR)
        return 0

    if options.action == 'check':
        # Missing dependencies are reported by _check_deps() but do not fail
        # the check.
        if not _check_deps():
            return 0
        with get_folder(options.outdir) as temp_dir:
            _gen_header_diff(temp_dir, temp_dir)

            with open(os.path.join(temp_dir, GENERATED_DIFF), 'r') as temp_file, \
                 open(os.path.join(THIRD_PARTY_DIR, GENERATED_DIFF), 'r') as curr_file:
                # Any difference between the freshly generated diff and the
                # existing one means the latter is stale.
                diff = ''.join(
                    difflib.unified_diff(temp_file.readlines(),
                                         curr_file.readlines()))
                if diff:
                    print(
                        "Changes resulted in new diffs between Dawn's and the "
                        "upstream webgpu.h. Please re-run "
                        "'third_party/webgpu-headers/cli gen' manually and "
                        "verify/commit the new diffs in "
                        "'third_party/webgpu-headers/webgpu.h.diff' if the "
                        "changes are expected.")
                    return 1

        return 0


# When run as a script, hand the full argument vector to main() and use its
# return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main(sys.argv))
