Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ utils / hipify / hipify_python.py

#!/usr/bin/env python
""" The Python Hipify script.
##
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
#               2017-2018 Advanced Micro Devices, Inc. and
#                         Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""
import argparse
import fnmatch
import re
import shutil
import sys
import os

from . import constants
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS

from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable
# Result of preprocessing a single file. preprocess_file_and_save_result()
# reads its "hipified_path" and "status" entries.
HipifyResult = Dict[str, Optional[str]]
# Maps an absolute file path (output_directory joined with the relative file
# path) to the HipifyResult produced for that file.
HipifyFinalResult = Dict[str, HipifyResult]
# Marker comment line identifying hipify-generated files.
# NOTE(review): presumably written at the top of generated C/C++ sources --
# confirm against the code that uses it.
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
# Module-level accumulator: filled by preprocess_file_and_save_result() and
# returned by preprocess(). It is shared across calls within one process.
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}

# Hardcode the PyTorch template map
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}


class InputError(Exception):
    """Exception raised for errors in the input.

    The message passed to the constructor is stored on ``self.message`` and
    is rendered by ``str()`` with an ``"Input error: "`` prefix.
    """

    def __init__(self, message):
        self.message = message
        super().__init__(message)

    def __str__(self):
        return f"Input error: {self.message}"


def openf(filename, mode):
    """Open ``filename`` in ``mode``, ignoring text encoding/decoding errors.

    Thin wrapper around the built-in ``open`` that always passes
    ``errors='ignore'``; the caller is responsible for closing the returned
    file object.
    """
    handle = open(filename, mode=mode, errors='ignore')
    return handle


# Color coding for printing
class bcolors:
    """ANSI escape sequences used to style text printed to a terminal."""
    # Each value is ESC ("\033") followed by "[<code>m". Prefix the text with
    # a style and append ENDC to switch the styling off again, as preprocess()
    # does with OKGREEN ... ENDC.
    HEADER = '\033[95m'  # SGR 95: bright magenta foreground
    OKBLUE = '\033[94m'  # SGR 94: bright blue foreground
    OKGREEN = '\033[92m'  # SGR 92: bright green foreground
    WARNING = '\033[93m'  # SGR 93: bright yellow foreground
    FAIL = '\033[91m'  # SGR 91: bright red foreground
    ENDC = '\033[0m'  # SGR 0: reset all attributes
    BOLD = '\033[1m'  # SGR 1: bold
    UNDERLINE = '\033[4m'  # SGR 4: underline

# To the programmer, the outputs of hipify are most likely intermediates.
# This class lets users of hipify request a cleanup by running hipify and
# the compilation inside a `with` block that instantiates this context
# manager class with keep_intermediates=False.
# The main use case is cpp_extensions, specifically the load method.
# It is a good idea to keep intermediates (in case of errors or to
# avoid recompiling unchanged files), but in cases where you don't want to
# keep them (e.g. in the CI), this can be used to remove the files.
class GeneratedFileCleaner:
    """Context manager that removes the files and directories it created.

    Files opened through :meth:`open` and directories created through
    :meth:`makedirs` are recorded only when they did not exist beforehand.
    On exit, unless ``keep_intermediates`` is true, the recorded files are
    unlinked and the recorded directories are removed in reverse creation
    order (children before parents).
    """
    def __init__(self, keep_intermediates=False):
        self.keep_intermediates = keep_intermediates
        # Absolute paths of files created through open().
        self.files_to_clean = set()
        # Absolute paths of directories created through makedirs(), in
        # creation order.
        self.dirs_to_clean = []

    def __enter__(self):
        return self

    def open(self, fn, *args, **kwargs):
        """Open ``fn`` like the built-in ``open`` and track it if newly created.

        Extra positional and keyword arguments are forwarded to ``open``.
        Any exception raised by ``open`` propagates unchanged.
        """
        existed = os.path.exists(fn)
        handle = open(fn, *args, **kwargs)
        # Register only after a successful open: a failed open (e.g. read
        # mode on a missing path) creates nothing, and registering it would
        # make __exit__ fail while trying to delete a file that never existed.
        if not existed:
            self.files_to_clean.add(os.path.abspath(fn))
        return handle

    def makedirs(self, dn, exist_ok=False):
        """Create directory ``dn`` and any missing parents, tracking each one.

        Raises ``FileExistsError`` (from ``os.mkdir``) when ``dn`` already
        exists and ``exist_ok`` is false.
        """
        parent, n = os.path.split(dn)
        if not n:
            # ``dn`` ended with a separator; split again to get the real leaf.
            parent, n = os.path.split(parent)
        if parent and n and not os.path.exists(parent):
            # Parents are created (and recorded) first, so reversing the
            # record on exit removes children before their parents.
            self.makedirs(parent, exist_ok=True)
        if not os.path.isdir(dn) or not exist_ok:
            os.mkdir(dn)
            self.dirs_to_clean.append(os.path.abspath(dn))

    def __exit__(self, type, value, traceback):
        """Delete tracked files, then tracked directories, unless keeping them.

        Returns ``None``, so an exception raised inside the ``with`` body is
        never suppressed.
        """
        if self.keep_intermediates:
            return
        for f in self.files_to_clean:
            try:
                os.unlink(f)
            except FileNotFoundError:
                # Already removed by someone else: the goal state is reached.
                pass
        for d in reversed(self.dirs_to_clean):
            try:
                os.rmdir(d)
            except FileNotFoundError:
                # Already removed by someone else: the goal state is reached.
                pass

def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Return True if ``filename`` ends with at least one of ``extensions``.

    ``extensions`` is scanned in order and the scan stops at the first
    suffix that matches; an empty iterable yields False.
    """
    for suffix in extensions:
        if filename.endswith(suffix):
            return True
    return False

def matched_files_iter(
        root_path: str,
        includes: Iterable = ('*',),
        ignores: Iterable = (),
        extensions: Iterable = (),
        out_of_place_only: bool = False,
        is_pytorch_extension: bool = False) -> Iterator[str]:
    """Walk ``root_path`` and yield the files selected by the given filters.

    Yielded paths are relative to ``root_path``; files directly under the
    root are yielded with a leading ``.`` component because they are joined
    onto the relative directory path.

    A file is selected when it matches an ``includes`` glob, matches no
    ``ignores`` glob, and either ends with one of ``extensions`` or is listed
    verbatim in ``includes``. Unless ``is_pytorch_extension`` is true, the
    file must additionally pass ``is_pytorch_file`` or ``is_caffe2_gpu_file``
    and, when ``out_of_place_only`` is set, ``is_out_of_place``.
    """
    def _matches_any(path, patterns):
        for pattern in patterns:
            if fnmatch.fnmatch(path, pattern):
                return True
        return False

    verbatim_includes = set(includes)

    # Rough heuristic: ideally only files under source control would be
    # scanned, but this must work in any kind of checkout, so only the
    # biggest time sinks at the top level are pruned from the walk.
    for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == '.':
            # Editing ``dirs`` in place stops os.walk from descending there.
            for pruned in (".git", "build", "third_party"):
                if pruned in dirs:
                    dirs.remove(pruned)
        for filename in filenames:
            candidate = os.path.join(rel_dirpath, filename)
            if not _matches_any(candidate, includes):
                continue
            if _matches_any(candidate, ignores):
                continue
            # Extensions are respected UNLESS the entire filename was written
            # verbatim in ``includes``, in which case it is always accepted.
            if not (match_extensions(candidate, extensions) or candidate in verbatim_includes):
                continue
            if not is_pytorch_extension:  # for pytorch extensions, consider all files
                if not (is_pytorch_file(candidate) or is_caffe2_gpu_file(candidate)):
                    continue
                if out_of_place_only and not is_out_of_place(candidate):
                    continue
            yield candidate


def preprocess_file_and_save_result(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        includes: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> None:
    """Run ``preprocessor`` on one file and record its result.

    All arguments are forwarded to ``preprocessor`` unchanged. When
    ``show_progress`` is true, the absolute path of the file together with the
    result's ``"hipified_path"`` and ``"status"`` is printed. The result is
    stored in ``HIPIFY_FINAL_RESULT`` under that absolute path only when its
    ``"hipified_path"`` is not ``None``.
    """
    outcome = preprocessor(output_directory, filepath, all_files, includes, stats,
                           hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    joined_path = os.path.join(output_directory, filepath)
    absolute_path = os.path.abspath(joined_path)

    # Show what happened
    if show_progress:
        print(absolute_path, "->", outcome["hipified_path"], outcome["status"])

    if outcome["hipified_path"] is None:
        return
    HIPIFY_FINAL_RESULT[absolute_path] = outcome


def preprocess(
        output_directory: str,
        all_files: Iterable,
        includes: Iterable,
        show_detailed: bool = False,
        show_progress: bool = True,
        hip_clang_launch: bool = False,
        is_pytorch_extension: bool = False,
        clean_ctx: GeneratedFileCleaner = None) -> HipifyFinalResult:
    """
    Call preprocessor on selected files.

    Arguments:
        output_directory - Forwarded to the per-file preprocessing step.
        all_files - Files to process; iterated once here and also forwarded
            as-is for every file.
        includes - Forwarded to the per-file preprocessing step.
        show_detailed - Show a detailed summary of the transpilation process.
        show_progress - Forwarded to the per-file preprocessing step.
        hip_clang_launch - Forwarded to the per-file preprocessing step.
        is_pytorch_extension - Forwarded to the per-file preprocessing step.
        clean_ctx - Cleaner tracking generated files; when None, a cleaner
            that keeps all intermediates is created.

    Returns the module-level HIPIFY_FINAL_RESULT mapping.
    """

    # Without a caller-supplied cleaner, fall back to one that keeps everything.
    clean_ctx = GeneratedFileCleaner(keep_intermediates=True) if clean_ctx is None else clean_ctx

    # Preprocessing statistics, filled in while the files are processed.
    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}

    for filepath in all_files:
        preprocess_file_and_save_result(
            output_directory, filepath, all_files, includes, stats,
            hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    success_banner = bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC
    print(success_banner, file=sys.stderr)

    # Show detailed summary
    if show_detailed:
        compute_stats(stats)

    return HIPIFY_FINAL_RESULT


def compute_stats(stats):
    """Print a summary of the preprocessing statistics to stdout.

    Arguments:
        stats - Mapping with an "unsupported_calls" list of
            (cuda_call, filepath) pairs and a "kernel_launches" list.

    Each unsupported call name is counted and listed once, however many
    files it appeared in.
    """
    unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}

    # Print the number of unsupported calls
    print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))

    # Print the list of unsupported calls. Sorted so the listing is
    # reproducible: raw set iteration order for strings changes between
    # interpreter runs because of hash randomization.
    print(", ".join(sorted(unsupported_calls)))

    # Print the number of kernel launches
    print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))


def add_dim3(kernel_string, cuda_kernel):
    """Wrap the first two launch arguments of a kernel launch in ``dim3()``.

    Arguments:
        kernel_string - The launch configuration text, including its
            ``<<<`` / ``>>>`` delimiters (they are stripped before parsing).
        cuda_kernel - Text containing that launch configuration.

    Returns ``cuda_kernel`` with every occurrence of the first two arguments
    (as they appear together in ``kernel_string``) replaced by their
    ``dim3(...)``-wrapped form. The original whitespace around each argument
    is preserved.

    The configuration is expected to hold at least two top-level arguments.
    """
    count = 0
    closure = 0
    kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
    arg_locs: List[Dict[str, int]] = [{} for _ in range(2)]
    arg_locs[count]['start'] = 0
    last_index = len(kernel_string) - 1
    for ind, c in enumerate(kernel_string):
        if count > 1:
            break
        # Track parenthesis depth so commas inside e.g. dim3(1, 2) are not
        # mistaken for argument separators.
        if c == "(":
            closure += 1
        elif c == ")":
            closure -= 1
        # Deliberately a separate ``if`` rather than part of the chain above:
        # when the configuration ends with ")" (e.g. "grid, dim3(1,2)") the
        # end-of-string check must still run, otherwise the second argument
        # never receives an 'end' and the lookup below raises KeyError.
        if (c == "," or ind == last_index) and closure == 0:
            # A comma is excluded from the argument; a final character is
            # included, hence the +1 when the terminator is not a comma.
            arg_locs[count]['end'] = ind + (c != ",")
            count += 1
            if count < 2:
                arg_locs[count]['start'] = ind + 1

    # The raw first argument keeps its trailing comma so that raw first +
    # raw second reproduces the exact source text to search for.
    first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
    second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]

    first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
    second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")

    first_arg_dim3 = "dim3({})".format(first_arg_clean)
    second_arg_dim3 = "dim3({})".format(second_arg_clean)

    first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
    second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
    cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
    return cuda_kernel


# Matches a namespace qualifier that is separated from what follows by a line
# continuation: one or more spaces (group 1), "detai" or "detail" (group 2),
# "::", one or more spaces, a backslash, a newline, then one or more spaces.
# processKernelLaunches() replaces each match with "<group 1><group 2>::",
# which removes the continuation and joins the namespace to the kernel name.
# NOTE(review): `detail?` makes only the final "l" optional; if an optional
# "detail" qualifier was intended this would be `(detail)?` -- confirm.
RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')


def processKernelLaunches(string, stats):
    """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
    # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
    string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)

    def grab_method_and_template(in_kernel):
        # The positions for relevant kernel components.
        pos = {
            "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
            "kernel_name": {"start": -1, "end": -1},
            "template": {"start": -1, "end": -1}
        }

        # Count for balancing template
        count = {"<>": 0}

        # Status for whether we are parsing a certain item.
        START = 0
        AT_TEMPLATE = 1
        AFTER_TEMPLATE = 2
        AT_KERNEL_NAME = 3

        status = START

        # Parse the string character by character
        for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
            char = string[i]

            # Handle Templating Arguments
            if status == START or status == AT_TEMPLATE:
                if char == ">":
                    if status == START:
                        status = AT_TEMPLATE
                        pos["template"]["end"] = i
                    count["<>"] += 1

                if char == "<":
                    count["<>"] -= 1
                    if count["<>"] == 0 and (status == AT_TEMPLATE):
                        pos["template"]["start"] = i
                        status = AFTER_TEMPLATE

            # Handle Kernel Name
            if status != AT_TEMPLATE:
                if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}:
                    if status != AT_KERNEL_NAME:
                        status = AT_KERNEL_NAME
                        pos["kernel_name"]["end"] = i

                    # Case: Kernel name starts the string.
                    if i == 0:
                        pos["kernel_name"]["start"] = 0

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

                else:
                    # Potential ending point if we're already traversing a kernel's name.
                    if status == AT_KERNEL_NAME:
                        pos["kernel_name"]["start"] = i

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

    def find_kernel_bounds(string):
        """Finds the starting and ending points for all kernel launches in the string."""
        kernel_end = 0
        kernel_positions = []

        # Continue until we cannot find any more kernels anymore.
        while string.find("<<<", kernel_end) != -1:
            # Get kernel starting position (starting from the previous ending point)
Loading ...