Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ utils / hipify / hipify_python.py

#!/usr/bin/env python3
""" The Python Hipify script.
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
#               2017-2018 Advanced Micro Devices, Inc. and
#                         Facebook Inc. All rights reserved.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
"""
import argparse
import fnmatch
import re
import shutil
import sys
import os

from . import constants
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS

from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable
# Result of hipifying one file: the keys "hipified_path" and "status" map to a
# string, or to None when there is no hipified output path.
HipifyResult = Dict[str, Optional[str]]
# Maps a file path to the HipifyResult recorded for that file.
HipifyFinalResult = Dict[str, HipifyResult]
# First line prepended to hipified C-style outputs; an input whose first line
# equals this marker is treated as hipify output and is not hipified again.
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
# Module-level record of per-file results, keyed by absolute input path.
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}

# Hardcode the PyTorch template map
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}

# Public API of this module. Each name is listed exactly once
# ('is_caffe2_gpu_file' used to appear twice).
__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
           'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
           'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
           'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_caffe2_gpu_file',
           'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
           'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'hipify']

class InputError(Exception):
    """Exception raised for errors in the input."""

    def __init__(self, message):
        # Kept as an attribute so callers can read the raw message.
        self.message = message

    def __str__(self):
        return f"Input error: {self.message}"

def openf(filename, mode):
    """Open ``filename`` in ``mode`` with encoding errors ignored."""
    handle = open(filename, mode, errors="ignore")
    return handle

# Color coding for printing
class bcolors:
    """Escape sequences used to colour and style printed output."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    # ENDC is emitted after a coloured message (see the warning print in preprocessor).
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

# To the programmer, the output of hipify most likely are intermediates.
# This class allows users of hipify to ask for a cleanup by running the
# hipify and compilation in a with instantiating this context manager class
# with keep_intermediates=False.
# The main usecase is the cpp_extensions, specifically the load method.
# It is a good idea to keep intermediates (in case of errors or to
# not recompile unchanged files), but in cases where you don't want to
# keep them (e.g. in the CI), this can be used to remove files.
class GeneratedFileCleaner:
    """Context Manager to clean up generated files.

    Files opened through :meth:`open` that did not exist beforehand, and
    directories created through :meth:`makedirs`, are recorded. On exit they
    are removed again unless ``keep_intermediates`` is true.
    """

    def __init__(self, keep_intermediates=False):
        self.keep_intermediates = keep_intermediates
        # Absolute paths of files this context created.
        self.files_to_clean = set()
        # Absolute paths of directories this context created, in creation order.
        self.dirs_to_clean = []

    def __enter__(self):
        return self

    def open(self, fn, *args, **kwargs):
        """Open ``fn`` like the builtin ``open``, recording it for cleanup if it is new."""
        # Only files that do not exist yet are ours to delete later.
        if not os.path.exists(fn):
            self.files_to_clean.add(os.path.abspath(fn))
        return open(fn, *args, **kwargs)

    def makedirs(self, dn, exist_ok=False):
        """Create ``dn`` and any missing parents, recording each created directory.

        Raises:
            OSError: from ``os.mkdir`` when ``dn`` already exists and
                ``exist_ok`` is false.
        """
        parent, n = os.path.split(dn)
        if not n:
            # dn ended with a separator; split once more to get the real leaf.
            parent, n = os.path.split(parent)
        if parent and n and not os.path.exists(parent):
            self.makedirs(parent, exist_ok=True)
        if not os.path.isdir(dn) or not exist_ok:
            os.mkdir(dn)
            self.dirs_to_clean.append(os.path.abspath(dn))

    def __exit__(self, type, value, traceback):
        if not self.keep_intermediates:
            for f in self.files_to_clean:
                os.unlink(f)
            # Remove children before parents: reverse creation order.
            for d in self.dirs_to_clean[::-1]:
                os.rmdir(d)

def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Return True when ``filename`` ends with at least one of ``extensions``."""
    for suffix in extensions:
        if filename.endswith(suffix):
            return True
    return False

def _fnmatch(filepath, patterns):
    """Return True when ``filepath`` matches any of the fnmatch-style ``patterns``."""
    for pattern in patterns:
        if fnmatch.fnmatch(filepath, pattern):
            return True
    return False

def matched_files_iter(
        root_path: str,
        includes: Iterable = (),
        ignores: Iterable = (),
        extensions: Iterable = (),
        out_of_place_only: bool = False,
        is_pytorch_extension: bool = False) -> Iterator[str]:
    """Yield the paths under ``root_path`` that should be hipified.

    A file is yielded when it matches one of the ``includes`` patterns, matches
    none of the ``ignores`` patterns, and either has one of ``extensions`` or
    is listed verbatim in ``includes``. Unless ``is_pytorch_extension`` is set,
    the file must also be a PyTorch or Caffe2 GPU file and, when
    ``out_of_place_only`` is set, an out-of-place one.
    """
    # Materialise the pattern collections: they are scanned once per file, so a
    # one-shot iterator would be exhausted after the first use.
    includes = tuple(includes)
    ignores = tuple(ignores)
    extensions = tuple(extensions)
    exact_matches = set(includes)

    # This is a very rough heuristic; really, we want to avoid scanning
    # any file which is not checked into source control, but this script
    # needs to work even if you're in a Git or Hg checkout, so easier to
    # just block the biggest time sinks that won't matter in the
    # end.
    for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == '.':
            # With topdown=True, editing `dirs` in place controls which
            # sub-directories os.walk descends into.
            if ".git" in dirs:
                dirs.remove(".git")
            if "build" in dirs:
                dirs.remove("build")
            if "third_party" in dirs:
                dirs.remove("third_party")
                # NOTE(review): reconstructed statement. nvfuser is kept because
                # is_pytorch_file() treats third_party/nvfuser/ as PyTorch
                # source; confirm against upstream.
                dirs.append("third_party/nvfuser")
        for filename in filenames:
            filepath = os.path.join(abs_dirpath, filename)
            rel_filepath = os.path.join(rel_dirpath, filename)
            # We respect extensions, UNLESS you wrote the entire
            # filename verbatim, in which case we always accept it
            if (
                _fnmatch(filepath, includes)
                and (not _fnmatch(filepath, ignores))
                and (match_extensions(filepath, extensions) or filepath in exact_matches)
            ):
                if not is_pytorch_extension:  # for pytorch extensions, consider all files
                    if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
                        continue
                    if out_of_place_only and not is_out_of_place(rel_filepath):
                        continue
                yield filepath

def preprocess_file_and_save_result(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> None:
    """Run ``preprocessor`` on ``filepath`` and record its result.

    The result is stored in ``HIPIFY_FINAL_RESULT`` under the absolute path of
    the input file. When ``show_progress`` is true and the status does not
    contain "ignored", a progress line is printed.
    """
    result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
                          hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    # Show what happened
    if show_progress and "ignored" not in str(result["status"]):
        print(
            fin_path, "->",
            result["hipified_path"], result["status"], flush=True)

    HIPIFY_FINAL_RESULT[fin_path] = result

def compute_stats(stats):
    """Print a summary of unsupported CUDA calls and replaced kernel launches.

    ``stats["unsupported_calls"]`` is read as (cuda_call, filepath) pairs and
    ``stats["kernel_launches"]`` as a sized collection.
    """
    distinct_calls = set()
    for cuda_call, _filepath in stats["unsupported_calls"]:
        distinct_calls.add(cuda_call)

    # How many distinct unsupported calls there are, then their names.
    print(f"Total number of unsupported CUDA function calls: {len(distinct_calls):d}")
    print(", ".join(distinct_calls))

    # How many kernel launches were rewritten.
    launch_count = len(stats["kernel_launches"])
    print(f"\nTotal number of replaced kernel launches: {launch_count:d}")

def add_dim3(kernel_string, cuda_kernel):
    '''Wrap the first two kernel launch parameters in dim3().

    Args:
        kernel_string: the launch configuration including its ``<<<`` / ``>>>``
            delimiters, e.g. ``"<<<grid, block>>>"``.
        cuda_kernel: the text containing that launch configuration.

    Returns:
        ``cuda_kernel`` with the first two comma-separated launch parameters
        replaced by ``dim3(<param>)``.
    '''
    count = 0
    closure = 0
    kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
    arg_locs: List[Dict[str, int]] = [{} for _ in range(2)]
    arg_locs[count]['start'] = 0
    for ind, c in enumerate(kernel_string):
        # Only the first two parameters are of interest; stop once both are
        # located (arg_locs has no slot for a third).
        if count > 1:
            break
        if c == "(":
            closure += 1
        elif c == ")":
            closure -= 1
        # A parameter ends at a top-level comma or at the end of the string.
        if (c == "," or ind == len(kernel_string) - 1) and closure == 0:
            arg_locs[count]['end'] = ind + (c != ",")
            count += 1
            if count < 2:
                arg_locs[count]['start'] = ind + 1

    # The raw first argument keeps its trailing comma so that raw1 + raw2 is a
    # contiguous slice of the launch configuration.
    first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
    second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]

    first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
    second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")

    first_arg_dim3 = "dim3({})".format(first_arg_clean)
    second_arg_dim3 = "dim3({})".format(second_arg_clean)

    first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
    second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
    cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
    return cuda_kernel

# Matches a run of spaces, then "detai" with an optional trailing "l", then
# "::", spaces, a backslash-newline line continuation and more spaces.
# processKernelLaunches uses it to rejoin such a split namespace prefix with
# the text that follows it.
RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')

def processKernelLaunches(string, stats):
    """Replace the CUDA style Kernel launches with the HIP style kernel launches.

    Each ``name<template><<<launch params>>>(`` occurrence found outside of
    comments and string literals is rewritten into a ``hipLaunchKernelGGL(``
    call. The first two launch parameters are wrapped in ``dim3()`` and the
    launch parameter list is padded with ``, 0`` up to four entries. Every
    generated launch is appended to ``stats["kernel_launches"]``.

    Raises:
        InputError: if a ``<<<`` is not followed by a ``>>>``.
    """
    # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
    string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)

    def grab_method_and_template(in_kernel):
        """Scan backwards from the launch to locate the kernel name and template."""
        # The positions for relevant kernel components.
        pos = {
            "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
            "kernel_name": {"start": -1, "end": -1},
            "template": {"start": -1, "end": -1}
        }

        # Count for balancing template
        count = {"<>": 0}

        # Status for whether we are parsing a certain item.
        START = 0
        AT_TEMPLATE = 1
        AFTER_TEMPLATE = 2
        AT_KERNEL_NAME = 3

        status = START

        # Parse the string character by character, right to left.
        for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
            char = string[i]

            # Handle Templating Arguments
            if status == START or status == AT_TEMPLATE:
                if char == ">":
                    if status == START:
                        status = AT_TEMPLATE
                        pos["template"]["end"] = i
                    count["<>"] += 1

                if char == "<":
                    count["<>"] -= 1
                    if count["<>"] == 0 and (status == AT_TEMPLATE):
                        pos["template"]["start"] = i
                        status = AFTER_TEMPLATE

            # Handle Kernel Name
            if status != AT_TEMPLATE:
                if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}:
                    if status != AT_KERNEL_NAME:
                        status = AT_KERNEL_NAME
                        pos["kernel_name"]["end"] = i

                    # Case: Kernel name starts the string.
                    if i == 0:
                        pos["kernel_name"]["start"] = 0

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

                else:
                    # Potential ending point if we're already traversing a kernel's name.
                    if status == AT_KERNEL_NAME:
                        pos["kernel_name"]["start"] = i

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

    def find_kernel_bounds(string):
        """Finds the starting and ending points for all kernel launches in the string."""
        kernel_end = 0
        kernel_positions = []

        # Continue until we cannot find any more kernels anymore.
        while True:
            # Get kernel starting position (starting from the previous ending point)
            kernel_start = string.find("<<<", kernel_end)
            if kernel_start == -1:
                break

            # Test the find() result itself: after adding 3 a "not found" (-1)
            # becomes 2 and can no longer be told apart from a real position.
            closing = string.find(">>>", kernel_start)
            if closing == -1:
                raise InputError("no kernel end found")

            # Get kernel ending position (adjust end point past the >>>)
            kernel_end = closing + 3

            # Add to list of traversed kernels
            kernel_positions.append({"start": kernel_start, "end": kernel_end,
                                     "group": string[kernel_start: kernel_end]})

        return kernel_positions

    # Replace comments and string literals from the code so that find_kernel_bounds does not
    # wrongly capture kernels in comments and string literals.
    # This function replaces them with "x" to keep positions.
    def mask_comments(string):
        in_comment = ''
        prev_c = ''
        masked = []
        for c in string:
            if in_comment == '':
                # Outside comments
                if c == '/' and prev_c == '/':
                    in_comment = '//'
                elif c == '*' and prev_c == '/':
                    in_comment = '/*'
                elif c == '"' and prev_c != '\\' and prev_c != "'":
                    in_comment = '"'
            elif in_comment == '//':
                # In // xxx
                if c == '\r' or c == '\n':
                    in_comment = ''
            elif in_comment == '/*':
                # In /* xxx */
                if c == '/' and prev_c == '*':
                    in_comment = ''
            elif in_comment == '"':
                # In ""
                if c == '"' and prev_c != '\\':
                    in_comment = ''
            prev_c = c
            # One output character per input character keeps every index aligned.
            masked.append(c if in_comment == '' else 'x')
        return ''.join(masked)

    # Grab positional ranges of all kernel launches
    get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
    output_string = string

    # Replace each CUDA kernel with a HIP kernel.
    for kernel in get_kernel_positions:
        # Get kernel components
        params = grab_method_and_template(kernel)

        # Find parenthesis after kernel launch
        parenthesis = string.find("(", kernel["end"])

        # Extract cuda kernel
        cuda_kernel = string[params[0]["start"]:parenthesis + 1]
        kernel_string = string[kernel['start']:kernel['end']]
        end_param_index = 0 if params[1]['end'] == -1 else 1
        kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1]
        cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
        # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
        num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))

        hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
            ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(
            ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")")

        # Replace cuda kernel with hip kernel
        output_string = output_string.replace(cuda_kernel, hip_kernel)

        # Update the statistics
        stats["kernel_launches"].append(hip_kernel)

    return output_string

def find_closure_group(input_string, start, group):
    """Generalization for finding a balancing closure group

         if group = ["(", ")"], then finds the first balanced parentheses.
         if group = ["{", "}"], then finds the first balanced bracket.

    Given an input string, a starting position in the input string, and the group type,
    find_closure_group returns the positions of group[0] and group[1] as a tuple.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)

    Returns ``(None, None)`` when no balanced group is found at or after ``start``.
    """

    inside_parenthesis = False
    parens = 0
    pos = start
    p_start, p_end = -1, -1

    while pos < len(input_string):
        if input_string[pos] == group[0]:
            if inside_parenthesis is False:
                # First opener: this is where the group starts.
                inside_parenthesis = True
                parens = 1
                p_start = pos
            else:
                # Nested opener: one more closer is needed.
                parens += 1
        elif input_string[pos] == group[1] and inside_parenthesis:
            parens -= 1

            if parens == 0:
                p_end = pos
                return p_start, p_end

        pos += 1
    return None, None

def find_bracket_group(input_string, start):
    """Finds the first balanced pair of curly brackets ("{" ... "}")."""
    curly = ["{", "}"]
    return find_closure_group(input_string, start, curly)

def find_parentheses_group(input_string, start):
    """Finds the first balanced pair of parentheses ("(" ... ")")."""
    round_pair = ["(", ")"]
    return find_closure_group(input_string, start, round_pair)

# Matches the whole word "assert", optional spaces, then an opening parenthesis.
# NOTE(review): no use of this pattern is visible in this part of the file.
RE_ASSERT = re.compile(r"\bassert[ ]*\(")

def replace_math_functions(input_string):
    """FIXME: Temporarily replace std:: invocations of math functions
        with non-std:: versions to prevent linker errors NOTE: This
        can lead to correctness issues when running tests, since the
        correct version of the math function (exp/expf) might not get
        called.  Plan is to remove this function once HIP supports
        std:: math function calls inside device code

    Every occurrence of ``<key>(`` for a key of MATH_TRANSPILATIONS is replaced
    by ``<value>(``; the rewritten string is returned.
    """
    output_string = input_string
    for func, replacement in MATH_TRANSPILATIONS.items():
        output_string = output_string.replace('{}('.format(func), '{}('.format(replacement))

    return output_string

# Matches "__syncthreads" as a whole word, optionally preceded by up to two
# colons, followed by an opening parenthesis. hip_header_magic uses it as one
# of its signals that a file contains device code.
RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")

def hip_header_magic(input_string):
    """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.
    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.

    Returns ``input_string`` unchanged when it already includes
    ``hip/hip_runtime.h`` or ``hip/hip_runtime_api.h`` (quoted or angle-bracket
    form), or when no device logic is detected. Otherwise returns it with
    ``#include "hip/hip_runtime.h"`` prepended.
    """
    # Check if one of the following headers is already included.
    headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
    for header in headers:
        # Escape the header name so that "." only matches a literal dot.
        escaped = re.escape(header)
        if re.search(r'#include ("{0}"|<{0}>)'.format(escaped), input_string):
            return input_string

    # Rough logic to detect if we're inside device code
    has_device_logic = (
        "hipLaunchKernelGGL" in input_string
        or "__global__" in input_string
        or "__shared__" in input_string
        or RE_SYNCTHREADS.search(input_string) is not None
    )

    # If device logic found, provide the necessary header.
    if has_device_logic:
        return '#include "hip/hip_runtime.h"\n' + input_string

    return input_string

# Matches a declaration of the form `extern [token] __shared__ <type> <name>[];`.
# Group 1: optional token between `extern` and `__shared__`.
# Group 2: the type (word characters, ':', '<', '>' and whitespace).
# Group 3: the variable name.
RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")

def replace_extern_shared(input_string):
    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.

    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, smem)"

    Returns the rewritten string.
    """
    def to_macro(inp):
        # Group 1 (the optional token after `extern`) may be absent; use "" then.
        return "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
            inp.group(1) or "", inp.group(2), inp.group(3))

    return RE_EXTERN_SHARED.sub(to_macro, input_string)

def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
    """
    Returns the new name of the hipified file

    ``rel_filepath`` must be a relative path. When not hipifying a PyTorch
    extension and the file is hipified in place, the path is returned unchanged.
    """
    # At the moment, some PyTorch source files are HIPified in place.  The predicate
    # is_out_of_place tells us if this is the case or not.
    assert not os.path.isabs(rel_filepath)
    if not is_pytorch_extension and not is_out_of_place(rel_filepath):
        return rel_filepath

    dirpath, filename = os.path.split(rel_filepath)
    root, ext = os.path.splitext(filename)

    # Here's the plan:
    # In general, we need to disambiguate the HIPified filename so that
    # it gets a different name from the original filename, so
    # that we don't overwrite the original file
    # There's a lot of different naming conventions across PyTorch
    # and Caffe2, but the general recipe is to convert occurrences
    # of cuda/gpu to hip, and add hip if there are no occurrences
    # of cuda/gpu anywhere.
    # Concretely, we do the following:
    #   - If there is a directory component named "cuda", replace
    #     it with "hip", AND
    #   - If the file name contains "CUDA", replace it with "HIP", AND
    #   - ALWAYS replace '.cu' with '.hip', because those files
    #     contain CUDA kernels that needs to be hipified and processed with
    #     hip compiler
    #   - If we are not hipifying a PyTorch extension, and the parent
    #     directory name did not change as a result of the above
    #     transformations, insert "hip" in the file path
    #     as the direct parent folder of the file
    #   - If we are hipifying a PyTorch extension, and the parent directory
    #     name as well as the filename (incl. extension) did not change as
    #     a result of the above transformations, insert "_hip" in the filename
    # This isn't set in stone; we might adjust this to support other
    # naming conventions.

    if ext == '.cu':
        ext = '.hip'

    orig_filename = filename
    orig_dirpath = dirpath

    dirpath = dirpath.replace('cuda', 'hip')
    dirpath = dirpath.replace('CUDA', 'HIP')
    dirpath = dirpath.replace('THC', 'THH')

    root = root.replace('cuda', 'hip')
    root = root.replace('CUDA', 'HIP')
    # Special case to handle caffe2/core/THCCachingAllocator
    if dirpath != "caffe2/core":
        root = root.replace('THC', 'THH')

    if not is_pytorch_extension and dirpath == orig_dirpath:
        dirpath = os.path.join(dirpath, 'hip')

    if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
        root = root + "_hip"

    return os.path.join(dirpath, root + ext)

def is_out_of_place(rel_filepath):
    """Return False for paths under the in-place prefixes listed below, True otherwise."""
    assert not os.path.isabs(rel_filepath)
    in_place_prefixes = (
        "torch/",
        "third_party/nvfuser/",
        "tools/autograd/templates/",
    )
    return not rel_filepath.startswith(in_place_prefixes)

# Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(rel_filepath):
    """Return True when the relative path lies in one of the PyTorch source trees.

    Paths under ``aten/`` count except for ``aten/src/ATen/core/``. Paths under
    ``torch/``, ``third_party/nvfuser/`` and ``tools/autograd/templates/`` count too.
    """
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("aten/"):
        # Everything in aten/ except the ATen core directory.
        return not rel_filepath.startswith("aten/src/ATen/core/")
    other_roots = ("torch/", "third_party/nvfuser/", "tools/autograd/templates/")
    return rel_filepath.startswith(other_roots)

def is_cusparse_file(rel_filepath):
    """Return True for PyTorch files whose path contains "sparse" (case-insensitive)."""
    if not is_pytorch_file(rel_filepath):
        return False
    lowered = rel_filepath.lower()
    return "sparse" in lowered

def is_caffe2_gpu_file(rel_filepath):
    """Return True when the relative path is treated as a Caffe2 GPU source.

    Anything starting with ``c10/cuda`` qualifies. Otherwise the file name must
    contain "gpu" or end in ``.cu`` / ``.cuh``, and must not contain "cudnn".
    """
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("c10/cuda"):
        return True
    filename = os.path.basename(rel_filepath)
    ext = os.path.splitext(filename)[1]
    # cudnn files are excluded even when they otherwise look like GPU sources.
    if 'cudnn' in filename:
        return False
    return 'gpu' in filename or ext in ['.cu', '.cuh']

# Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
class Trie():
    """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union."""

    def __init__(self):
        # Nested dicts keyed by character; the key '' marks the end of a word.
        self.data = {}

    def add(self, word):
        """Insert ``word`` into the trie."""
        ref = self.data
        for char in word:
            ref = ref.setdefault(char, {})
        ref[''] = 1

    def dump(self):
        """Return the underlying nested-dict representation."""
        return self.data

    def quote(self, char):
        """Return ``char`` escaped for use inside a regular expression."""
        return re.escape(char)

    def _pattern(self, pData):
        """Return the regex source for the sub-trie ``pData``.

        Returns None when the sub-trie holds nothing but an end-of-word marker.
        """
        data = pData
        if "" in data and len(data.keys()) == 1:
            return None

        alt = []  # alternatives that continue with a longer suffix
        cc = []   # single characters that end a word (go into a character class)
        q = 0     # set when a word may end at this node, making the rest optional
        for char in sorted(data.keys()):
            if isinstance(data[char], dict):
                recurse = self._pattern(data[char])
                if recurse is None:
                    # Nothing follows this character: it ends a word here.
                    cc.append(self.quote(char))
                else:
                    alt.append(self.quote(char) + recurse)
            else:
                q = 1
        cconly = not len(alt) > 0

        if len(cc) > 0:
            if len(cc) == 1:
                alt.append(cc[0])
            else:
                alt.append('[' + ''.join(cc) + ']')

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if q:
            if cconly:
                result += "?"
            else:
                result = "(?:%s)?" % result
        return result

    def pattern(self):
        """Return the regex source matching exactly the words added so far."""
        return self._pattern(self.dump())

# Lookup tables built from CUDA_TO_HIP_MAPPINGS. Each *_TRIE collects the source
# identifiers for one flavour and each *_MAP gives their replacements.
CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}
PYTORCH_TRIE = Trie()
PYTORCH_MAP: Dict[str, object] = {}

# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
PYTORCH_SPARSE_MAP = {}

for mapping in CUDA_TO_HIP_MAPPINGS:
    assert isinstance(mapping, Mapping)
    for src, value in mapping.items():
        # value is (replacement, *metadata flags)
        dst = value[0]
        meta_data = value[1:]
        if constants.API_CAFFE2 not in meta_data:
            PYTORCH_TRIE.add(src)
            # if src is already in PYTORCH_MAP and dst belongs to API_SPARSE
            # do not overwrite PYTORCH_MAP, store dst separately
            if constants.API_SPARSE in meta_data and PYTORCH_MAP.get(src, ""):
                PYTORCH_SPARSE_MAP[src] = dst
            else:
                PYTORCH_MAP[src] = dst
        if constants.API_PYTORCH not in meta_data:
            CAFFE2_TRIE.add(src)
            CAFFE2_MAP[src] = dst
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
# The PyTorch pattern only matches identifiers delimited by non-word characters.
RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.pattern()))

# Captures the path of a quoted include:  #include "path"
RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
# Captures the path of an angle-bracket include:  #include <path>
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
# Captures the quoted file name of a THC_GENERIC_FILE define.
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh

# Returns a dict with the following keys:
#     "hipified_path" : absolute path of hipified source file
#     "status"        : "ok"      if hipified file was written out
#                       "skipped" if an identical hipified file already existed or hipified file couldn't be written out
#                       "ignored" if the source file was a hipified file itself or not meant to be hipified

def preprocessor(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> HipifyResult:
    """ Executes the CUDA -> HIP conversion on the specified file.

    Returns a dict with the keys "hipified_path" and "status". The status is one of
    "[ok]", "[skipped, ...]" or "[ignored, ...]"; "hipified_path" is None for
    ignored files.
    """
    if filepath not in all_files:
        return {"hipified_path": None, "status": "[ignored, not to be hipified]"}

    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    rel_filepath = os.path.relpath(filepath, output_directory)

    with open(fin_path, 'r', encoding='utf-8') as fin:
        if fin.readline() == HIPIFY_C_BREADCRUMB:
            return {"hipified_path": None, "status": "[ignored, input is hipified output]"}
        # Rewind: the breadcrumb check consumed the first line, which is part
        # of the source when the file is not hipify output.
        fin.seek(0)
        output_source = fin.read()

    orig_output_source = output_source

    # get_hip_file_path needs a relative path to work correctly
    fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
    if not os.path.exists(os.path.dirname(fout_path)):
        # Create through the cleaner so the directory can be removed afterwards.
        clean_ctx.makedirs(os.path.dirname(fout_path))

    # unsupported_calls statistics reporting is broken atm
    def pt_repl(m):
        return PYTORCH_MAP[m.group(0)]

    def pt_sparse_repl(m):
        # checks SPARSE map first, and if a miss occurs, falls back to pytorch mappings
        return PYTORCH_SPARSE_MAP.get(m.group(0), pt_repl(m))

    if is_pytorch_extension:
        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
    else:
        if is_cusparse_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_sparse_repl, output_source)
        elif is_pytorch_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
        else:
            def c2_repl(m):
                return CAFFE2_MAP[m.group(0)]
            output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)

    # Header rewrites
    def mk_repl(templ, include_current_dir=True):
        def repl(m):
            f = m.group(1)
            dirpath, filename = os.path.split(f)
            if (
                # NOTE(review): the first operand of this condition is
                # reconstructed; confirm "ATen/cuda" against upstream.
                f.startswith("ATen/cuda")
                or f.startswith("ATen/native/cuda")
                or f.startswith("ATen/native/nested/cuda")
                or f.startswith("ATen/native/quantized/cuda")
                or f.startswith("ATen/native/sparse/cuda")
                or f.startswith("ATen/native/transformers/cuda")
                or f.startswith("THC/")
                or (f.startswith("THC") and not f.startswith("THCP"))
            ):
                return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
            # if filename is one of the files being hipified for this extension
            if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
                header_dir = None
                header_filepath = None
                # If include_current_dir True, look first in same dir as the including source file
                if include_current_dir:
                    header_dir_to_check = os.path.dirname(fin_path)
                    header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
                    if os.path.exists(header_path_to_check):
                        header_dir = header_dir_to_check
                        header_filepath = header_path_to_check
                # If not found, look in include dirs one by one and first match wins
                if header_filepath is None:
                    for header_include_dir in header_include_dirs:
                        header_dir_to_check = os.path.join(output_directory, header_include_dir)
                        header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
                        if os.path.exists(header_path_to_check):
                            header_dir = header_dir_to_check
                            header_filepath = header_path_to_check
                # If header file not found, keep as is
                if header_filepath is None:
                    return m.group(0)
                # Hipify header file first if needed
                if header_filepath not in HIPIFY_FINAL_RESULT:
                    preprocess_file_and_save_result(output_directory,
                                                    header_filepath,
                                                    all_files, header_include_dirs, stats, hip_clang_launch,
                                                    is_pytorch_extension, clean_ctx, show_progress)
                hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath]["hipified_path"]
                return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
                                                    else header_filepath, header_dir))

            return m.group(0)
        return repl
    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
    output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
    output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)

    # CMakeLists.txt rewrites
    if filepath.endswith('CMakeLists.txt'):
        output_source = output_source.replace('CUDA', 'HIP')
        output_source = output_source.replace('THC', 'THH')
        output_source = RE_CU_SUFFIX.sub('.hip', output_source)

    # Perform Kernel Launch Replacements
    if not hip_clang_launch:
        output_source = processKernelLaunches(output_source, stats)

    # Replace std:: with non-std:: versions
    if (filepath.endswith(".cu") or filepath.endswith(".cuh")) and "PowKernel" not in filepath:
        output_source = replace_math_functions(output_source)

    # Include header if device code is contained.
    output_source = hip_header_magic(output_source)

    # Replace the extern __shared__
    # NOTE: No longer needed after transition from hcc to hipclang.
    # output_source = replace_extern_shared(output_source)

    # Don't write out identical hipified files for extensions if dirpath has not changed
    if (
        is_pytorch_extension
        and orig_output_source == output_source
        and os.path.dirname(fin_path) == os.path.dirname(fout_path)
    ):
        return {"hipified_path": fin_path, "status": "[skipped, no changes]"}

    # Add hipify breadcrumb for C-style files to avoid re-hipification
    if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
        output_source = HIPIFY_C_BREADCRUMB + output_source

    # Skip the write when an identical output file is already on disk.
    do_write = True
    if os.path.exists(fout_path):
        with open(fout_path, 'r', encoding='utf-8') as fout_old:
            do_write = fout_old.read() != output_source
    if do_write:
        try:
            with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
                fout.write(output_source)
            return {"hipified_path": fout_path, "status": "[ok]"}
        except PermissionError as e:
            print(f"{bcolors.WARNING}Failed to save {fout_path} with \"{e.strerror}\", leaving {fin_path} unchanged.{bcolors.ENDC}",
                  file=sys.stderr)
            return {"hipified_path": fin_path, "status": "[skipped, no permissions]"}
    else:
        return {"hipified_path": fout_path, "status": "[skipped, already hipified]"}

def file_specific_replacement(filepath, search_string, replace_string, strict=False):
    """Replace occurrences of ``search_string`` in the file at ``filepath``, in place.

    Args:
        filepath: Path of the file to rewrite; it is opened through ``openf`` in "r+" mode.
        search_string: Literal text to search for (it is escaped, never treated as a regex).
        replace_string: Literal text substituted for each match.
        strict: When True, only occurrences delimited by regex word boundaries are
            replaced; when False, every occurrence is replaced.
    """
    with openf(filepath, "r+") as f:
        contents = f.read()
        if strict:
            # A callable replacement keeps any backslashes in replace_string literal.
            contents = re.sub(r'\b({0})\b'.format(re.escape(search_string)), lambda x: replace_string, contents)
        else:
            contents = contents.replace(search_string, replace_string)
        # Rewind and overwrite, then truncate in case the new text is shorter than the old.
        f.seek(0)
        f.write(contents)
        f.truncate()

def file_add_header(filepath, header):
    """Prepend an ``#include`` line for ``header`` to the file at ``filepath``, in place.

    Args:
        filepath: Path of the file to modify; it is opened through ``openf`` in "r+" mode.
        header: Header name. It is wrapped in double quotes unless it starts with
            ``<`` or ends with ``>``. Must be non-empty.
    """
    with openf(filepath, "r+") as f:
        contents = f.read()
        if header[0] != "<" and header[-1] != ">":
            header = '"{0}"'.format(header)
        contents = ('#include {0} \n'.format(header)) + contents
        # Rewind and overwrite the file with the new contents; truncate for safety.
        f.seek(0)
        f.write(contents)
        f.truncate()

def fix_static_global_kernels(in_txt):
    """Strip the ``static`` qualifier from ``__global__ static`` kernels.

    HIP fails to compile static global kernels, so every occurrence of
    " __global__ static" (with its leading space) is rewritten to "__global__".
    """
    static_kernel_marker = " __global__ static"
    plain_kernel_marker = "__global__"
    return in_txt.replace(static_kernel_marker, plain_kernel_marker)

# Matches one "#include <anything>" directive up to and including the newline that ends its line.
RE_INCLUDE = re.compile(r"#include .*\n")

def extract_arguments(start, string):
    """ Return the list of arguments in the upcoming function parameter closure.

        ``start`` is the index in ``string`` of the opening parenthesis of the
        argument list. Each returned entry is a dict holding the ``start``
        (inclusive) and ``end`` (exclusive) offsets of one argument, so
        ``string[arg["start"]:arg["end"]]`` is that argument's text. Commas
        nested inside parentheses or inside ``<...>`` do not split arguments.

        Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
            {'start': 8, 'end': 16},
            {'start': 17, 'end': 19},
            {'start': 20, 'end': 53}]'
    """

    arguments = []
    # Current nesting depth of each bracket kind.
    closures = {
        "<": 0,
        "(": 0
    }
    current_position = start
    argument_start_pos = current_position + 1

    # Search for final parenthesis
    while current_position < len(string):
        current_char = string[current_position]
        if current_char == "(":
            closures["("] += 1
        elif current_char == ")":
            closures["("] -= 1
        elif current_char == "<":
            closures["<"] += 1
        elif current_char == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
            # The ">" of a "->" operator does not close an angle bracket.
            closures["<"] -= 1

        # Finished all arguments
        if closures["("] == 0 and closures["<"] == 0:
            # Add final argument, then stop so trailing text is not scanned.
            arguments.append({"start": argument_start_pos, "end": current_position})
            break

        # Finished current argument
        if closures["("] == 1 and closures["<"] == 0 and current_char == ",":
            arguments.append({"start": argument_start_pos, "end": current_position})
            argument_start_pos = current_position + 1

        current_position += 1

    return arguments

def str2bool(v):
    """ArgumentParser doesn't support type=bool. Thus, this helper method will convert
    from possible string types to True / False.

    Matching is case-insensitive. Raises argparse.ArgumentTypeError when ``v`` is
    not one of the recognized spellings.
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def hipify(
    project_directory: str,
    show_detailed: bool = False,
    extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
    header_extensions: Iterable = (".cuh", ".h", ".hpp"),
    output_directory: str = "",
    header_include_dirs: Iterable = (),
    includes: Iterable = ('*',),
    extra_files: Iterable = (),
    out_of_place_only: bool = False,
    ignores: Iterable = (),
    show_progress: bool = True,
    hip_clang_launch: bool = False,
    is_pytorch_extension: bool = False,
    hipify_extra_files_only: bool = False,
    clean_ctx: Optional[GeneratedFileCleaner] = None
) -> HipifyFinalResult:
    if project_directory == "":
        project_directory = os.getcwd()

    # Verify the project directory exists.
    if not os.path.exists(project_directory):
        print("The project folder specified does not exist.")

    # If no output directory, provide a default one.
    if not output_directory:
        output_directory = project_directory + "_amd"

    if project_directory != output_directory:
        includes = [include.replace(project_directory, output_directory) for include in includes]
        ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]

    # Copy from project directory to output directory if not done already.
    if not os.path.exists(output_directory):
        shutil.copytree(project_directory, output_directory)

    all_files = list(matched_files_iter(output_directory, includes=includes,
                                        ignores=ignores, extensions=extensions,
    all_files_set = set(all_files)
    for f in extra_files:
        if not os.path.isabs(f):
            f = os.path.join(output_directory, f)
        if f not in all_files_set:

    # List all files in header_include_paths to ensure they are hipified
    from pathlib import Path
    for header_include_dir in header_include_dirs:
        if os.path.isabs(header_include_dir):
            header_include_dir_path = Path(header_include_dir)
            header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
        for path in header_include_dir_path.rglob('*'):
            if (
                and _fnmatch(str(path), includes)
                and (not _fnmatch(str(path), ignores))
                and match_extensions(path.name, header_extensions)

    if clean_ctx is None:
        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)

    # Preprocessing statistics.
    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}

    for filepath in (all_files if not hipify_extra_files_only else extra_files):
        preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
                                        stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)

    # Show detailed summary
    if show_detailed: