Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ testing / check_kernel_launches.py

import os
import re
import sys


# Regular expression that flags CUDA kernel launches which are not
# immediately followed by `C10_CUDA_KERNEL_LAUNCH_CHECK();`.
#
# A launch is approximated as the pattern ">>>(arguments);", where the
# arguments may span several lines but may not contain a semicolon. The
# match is anchored at the start of the line containing ">>>"
# (re.MULTILINE), so the matched text begins with that whole line.
#
# The negative lookahead requires `C10_CUDA_KERNEL_LAUNCH_CHECK();` to be
# the next command. Between the launch and the check it tolerates:
# * whitespace, including newlines;
# * a single backslash `\`, which handles cases where the kernel launch
#   is in a multiline preprocessor definition;
# * a "<number>: " line prefix, as added by
#   `check_code_for_cuda_kernel_launches` before it searches the code.
#
# NOTE: under re.VERBOSE the literal space inside `(?:[0-9]+: )?` is
# ignored by the regex engine; the space after the colon is consumed by
# the `\s*` that follows it.
#
# The pattern has no capturing groups, so `findall` returns the full text
# of each offending launch.
#
# There are various ways this can fail:
# * If the semicolon is in a string for some reason
# * If there's a triply-nested template (presumably because its closing
#   `>>>` resembles a launch)
# But this should be sufficient to detect and fix most problem
# instances and can be refined before the test is made binding
kernel_launch_regex = re.compile(r"""
    ^.*>>>        # Identifies kernel launch
    \s*           # Maybe some whitespace (includes newlines)
    \([^;]+\);    # And then arguments in parens and semi-colon
    (?!           # Negative lookahead: we trigger if we don't find the launch guard
        \s*                                  # Maybe some whitespace (includes newlines)
        \\?                                  # 0 or 1 backslashes (for launches in preprocessor macros)
        \s*                                  # Maybe some whitespace (includes newlines)
        (?:[0-9]+: )?                        # Detects and ignores a line numbering, if present
        \s*                                  # Maybe some whitespace (includes newlines)
        C10_CUDA_KERNEL_LAUNCH_CHECK\(\);  # Kernel launch guard!
    )             # End negative lookahead
""", flags=re.MULTILINE | re.VERBOSE)


def check_code_for_cuda_kernel_launches(code, filename=None):
    """Checks code for CUDA kernel launches without cuda error checks.

    Every offending launch is reported on stderr together with the matched
    source text, each line of which is prefixed by its line number.

    Args:
        code     - The code to check
        filename - Filename of file containing the code. Used only for display
                   purposes, so you can put anything here.

    Returns:
        The number of unsafe kernel launches in the code
    """
    if filename is None:
        filename = "##Python Function Call##"

    # Prefix every line with "<lineno>: " so the reported context pinpoints
    # the problem area. Numbering starts at 1 to agree with the line numbers
    # shown by editors and compilers. The regex tolerates this prefix between
    # a launch and its check, so the numbering does not affect what matches.
    numbered_code = "\n".join(
        f"{lineno}: {linecode}"
        for lineno, linecode in enumerate(code.split("\n"), start=1)
    )

    # The pattern has no capturing groups, so each result is the full text of
    # one launch that is missing its check.
    results = kernel_launch_regex.findall(numbered_code)
    for r in results:
        print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{r}", file=sys.stderr)
    return len(results)


def check_file(filename):
    """Checks a file for CUDA kernel launches without cuda error checks

    Only files whose name ends in ".cu" or ".cuh" are read; any other file
    is skipped without being opened.

    Args:
        filename - File to check

    Returns:
        The number of unsafe kernel launches in the file (0 for skipped files)
    """
    if not filename.endswith((".cu", ".cuh")):
        return 0
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(filename, "r") as source_file:
        contents = source_file.read()
    return check_code_for_cuda_kernel_launches(contents, filename)


def check_cuda_kernel_launches():
    """Checks all pytorch code for CUDA kernel launches without cuda error checks

    Walks the directory tree rooted two levels above the directory holding
    this file, runs `check_file` on every file found, and prints a summary
    to stderr if any unchecked launches were detected.

    Returns:
        The number of unsafe kernel launches in the codebase
    """
    torch_dir = os.path.dirname(os.path.realpath(__file__))
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2

    # `$BASE/build` and `$BASE/torch/include` are generated
    # so we don't want to flag their contents.
    # The paths are joined component-wise so they compare equal to the roots
    # yielded by os.walk on every platform (a hard-coded "/" would never
    # match on Windows).
    generated_dirs = {
        os.path.join(torch_dir, "build"),
        os.path.join(torch_dir, "torch", "include"),
    }

    kernels_without_checks = 0
    files_without_checks = []
    for root, dirnames, filenames in os.walk(torch_dir):
        if root in generated_dirs:
            # Curtail search by emptying dirnames in place so os.walk does
            # not descend further; this directory's own files are skipped by
            # the `continue`. Yes, this is the way to do this, see
            # `help(os.walk)`
            dirnames[:] = []
            continue

        for x in filenames:
            filename = os.path.join(root, x)
            file_result = check_file(filename)
            if file_result > 0:
                kernels_without_checks += file_result
                files_without_checks.append(filename)

    if kernels_without_checks > 0:
        count_str = f"Found {kernels_without_checks} instances in " \
                    f"{len(files_without_checks)} files where kernel " \
                    "launches didn't have checks."
        # The summary is printed both before and after the file list.
        print(count_str, file=sys.stderr)
        print("Files without checks:", file=sys.stderr)
        for x in files_without_checks:
            print(f"\t{x}", file=sys.stderr)
        print(count_str, file=sys.stderr)

    return kernels_without_checks


if __name__ == "__main__":
    # Run the scan for its stderr diagnostics; the returned count is not
    # used, and the exit status is always 0.
    check_cuda_kernel_launches()
    sys.exit(0)