import copy
import glob
import imp
import os
import re
import shlex
import setuptools
import subprocess
import sys
import sysconfig
import warnings
import collections
import torch
import torch._appdirs
from .file_baton import FileBaton
from ._cpp_extension_versioner import ExtensionVersioner
from .hipify import hipify_python
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
from typing import List, Optional, Union
from setuptools.command.build_ext import build_ext
from pkg_resources import packaging # type: ignore
IS_WINDOWS = sys.platform == 'win32'
LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
EXEC_EXT = '.exe' if IS_WINDOWS else ''
CLIB_PREFIX = '' if IS_WINDOWS else 'lib'
CLIB_EXT = '.dll' if IS_WINDOWS else '.so'
SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared'
_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')
BUILD_SPLIT_CUDA = os.getenv('BUILD_SPLIT_CUDA') or (os.path.exists(os.path.join(
TORCH_LIB_PATH, f'{CLIB_PREFIX}torch_cuda_cu{CLIB_EXT}')) and os.path.exists(os.path.join(TORCH_LIB_PATH, f'{CLIB_PREFIX}torch_cuda_cpp{CLIB_EXT}')))
# Taken directly from python stdlib < 3.9
# See https://github.com/pytorch/pytorch/issues/48617
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
"""Quote command-line arguments for DOS/Windows conventions.
Just wraps every argument which contains blanks in double quotes, and
returns a new argument list.
"""
# Cover None-type
if not args:
return []
return [f'"{arg}"' if ' ' in arg else arg for arg in args]
def _find_cuda_home() -> Optional[str]:
r'''Finds the CUDA install path.'''
# Guess #1
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
if cuda_home is None:
# Guess #2
try:
which = 'where' if IS_WINDOWS else 'which'
with open(os.devnull, 'w') as devnull:
nvcc = subprocess.check_output([which, 'nvcc'],
stderr=devnull).decode().rstrip('\r\n')
cuda_home = os.path.dirname(os.path.dirname(nvcc))
except Exception:
# Guess #3
if IS_WINDOWS:
cuda_homes = glob.glob(
'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
if len(cuda_homes) == 0:
cuda_home = ''
else:
cuda_home = cuda_homes[0]
else:
cuda_home = '/usr/local/cuda'
if not os.path.exists(cuda_home):
cuda_home = None
if cuda_home and not torch.cuda.is_available():
print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'")
return cuda_home
def _find_rocm_home() -> Optional[str]:
r'''Finds the ROCm install path.'''
# Guess #1
rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
if rocm_home is None:
# Guess #2
try:
hipcc = subprocess.check_output(
['which', 'hipcc'], stderr=subprocess.DEVNULL).decode().rstrip('\r\n')
# this will be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
rocm_home = os.path.dirname(os.path.dirname(hipcc))
if os.path.basename(rocm_home) == 'hip':
rocm_home = os.path.dirname(rocm_home)
except Exception:
# Guess #3
rocm_home = '/opt/rocm'
if not os.path.exists(rocm_home):
rocm_home = None
if rocm_home and torch.version.hip is None:
print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
return rocm_home
def _join_rocm_home(*paths) -> str:
    r'''
    Joins ``paths`` onto ROCM_HOME.

    Checking ROCM_HOME here, instead of at import time, means a missing ROCm
    install only becomes an error once a ROCm-specific path is requested.

    Raises:
        EnvironmentError: if ROCM_HOME is not set, or when running on Windows.
    '''
    root = ROCM_HOME
    if root is None:
        raise EnvironmentError('ROCM_HOME environment variable is not set. '
                               'Please set it to your ROCm install root.')
    if IS_WINDOWS:
        raise EnvironmentError('Building PyTorch extensions using '
                               'ROCm and Windows is not supported.')
    return os.path.join(root, *paths)
# Minimum compiler versions enforced by check_compiler_abi_compatibility,
# compared as integer tuples: GCC on Linux, MSVC on Windows.
MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)
# Warning template for a too-old compiler; formatted positionally with a
# "<compiler> <version>" description.
ABI_INCOMPATIBILITY_WARNING = '''
!! WARNING !!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.
See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 5 or higher.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! WARNING !!
'''
# Warning template for an unexpected compiler family; formatted with the
# keyword fields user_compiler, pytorch_compiler and platform.
WRONG_COMPILER_WARNING = '''
!! WARNING !!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({user_compiler}) is not compatible with the compiler Pytorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.
See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! WARNING !!
'''
# Toolkit install locations, resolved once at import time (ROCm first, then CUDA).
ROCM_HOME = _find_rocm_home()
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
# HIP extensions need both a ROCm install and a torch that reports a HIP version.
IS_HIP_EXTENSION = ROCM_HOME is not None and torch.version.hip is not None
# (major, minor) parsed from torch.version.hip; None when torch reports no HIP.
if torch.version.hip is None:
    ROCM_VERSION = None
else:
    ROCM_VERSION = tuple(map(int, torch.version.hip.split('.')[:2]))
CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# Released PyTorch versions look like major.minor.patch; builds from source
# append the git commit hash (something like ``1.10.0a0+git1234abc``), which
# is the shape this pattern recognises.
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

# Baseline flag lists shared by the compile steps of this module (the use
# sites are outside this section).
COMMON_MSVC_FLAGS = [
    '/MD',
    '/wd4819', '/wd4251', '/wd4244', '/wd4267',
    '/wd4275', '/wd4018', '/wd4190',
    '/EHsc',
]
# cudafe diagnostic names -- presumably suppressed when nvcc runs with MSVC;
# confirm at the use site.
MSVC_IGNORE_CUDAFE_WARNINGS = [
    'base_class_has_different_dll_interface', 'field_without_dll_interface',
    'dll_interface_conflict_none_assumed', 'dll_interface_conflict_dllexport_assumed',
]
COMMON_NVCC_FLAGS = [
    '-D__CUDA_NO_HALF_OPERATORS__', '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', '-D__CUDA_NO_HALF2_OPERATORS__',
    '--expt-relaxed-constexpr',
]
COMMON_HIP_FLAGS = ['-fPIC', '-D__HIP_PLATFORM_HCC__=1']
COMMON_HIPCC_FLAGS = [
    '-DCUDA_HAS_FP16=1', '-D__HIP_NO_HALF_OPERATORS__=1', '-D__HIP_NO_HALF_CONVERSIONS__=1',
]
# Single module-wide ExtensionVersioner instance for JIT-built extensions.
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
# Platform name -> architecture string, presumably passed to MSVC's vcvars
# setup; confirm at the use site.
PLAT_TO_VCVARS = {'win32': 'x86', 'win-amd64': 'x86_amd64'}
def _is_binary_build() -> bool:
    r'''Returns True unless torch's version string looks like a source build.

    A source build is recognised by BUILT_FROM_SOURCE_VERSION_PATTERN
    matching at the start of ``torch.version.__version__``.
    '''
    source_build_match = BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
    return source_build_match is None
def _accepted_compilers_for_platform() -> List[str]:
# gnu-c++ and gnu-cc are the conda gcc compilers
return ['clang++', 'clang'] if sys.platform.startswith('darwin') else ['g++', 'gcc', 'gnu-c++', 'gnu-cc']
def get_default_build_root() -> str:
    r'''
    Returns the root folder under which extensions are built.

    Every extension gets its own subfolder: if ``p`` is the returned path and
    ``ext`` is an extension's name, that extension is built in ``p/ext``.

    The location is the per-user cache directory for ``torch_extensions``
    (with symlinks resolved), so several users on one machine do not run
    into permission problems.
    '''
    user_cache = torch._appdirs.user_cache_dir(appname='torch_extensions')
    return os.path.realpath(user_cache)
def check_compiler_ok_for_platform(compiler: str) -> bool:
    r'''
    Checks whether ``compiler`` belongs to the compiler family expected here.

    Args:
        compiler (str): The compiler executable to check.

    Returns:
        Always True on Windows. Elsewhere, True when the compiler's resolved
        path contains one of the names from ``_accepted_compilers_for_platform()``
        (gcc/g++ family on Linux, clang/clang++ on macOS), or when its ``-v``
        output identifies such a compiler (e.g. behind a ccache wrapper);
        False otherwise.
    '''
    if IS_WINDOWS:
        return True

    accepted_names = _accepted_compilers_for_platform()

    def looks_accepted(path: str) -> bool:
        return any(name in path for name in accepted_names)

    # realpath follows symlinks, so a generic `c++` resolves to e.g. `g++`.
    located = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
    if looks_accepted(os.path.realpath(located.decode().strip())):
        return True

    # Behind ccache the resolved path is /usr/bin/ccache, so ask the compiler
    # itself via its verbose output.
    verbose_output = subprocess.check_output(
        [compiler, '-v'], stderr=subprocess.STDOUT).decode()

    if sys.platform.startswith('linux'):
        # Expect exactly one COLLECT_GCC=<driver path> line (gcc's -v output).
        drivers = re.findall(re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE), verbose_output)
        if len(drivers) != 1:
            return False
        return looks_accepted(os.path.realpath(drivers[0].strip()))

    if sys.platform.startswith('darwin'):
        return verbose_output.startswith("Apple clang")

    return False
def check_compiler_abi_compatibility(compiler: str) -> bool:
    r'''
    Verifies that the given compiler is ABI-compatible with PyTorch.

    The check is skipped (returns True) for source builds of PyTorch, or when
    TORCH_DONT_CHECK_COMPILER_ABI is one of ON/1/YES/TRUE/Y.

    Args:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.

    Returns:
        False if the compiler is (likely) ABI-incompatible with PyTorch,
        else True. A warning is emitted whenever False is returned.
    '''
    if not _is_binary_build():
        return True
    if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
        return True

    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(WRONG_COMPILER_WARNING.format(
            user_compiler=compiler,
            pytorch_compiler=_accepted_compilers_for_platform()[0],
            platform=sys.platform))
        return False

    if sys.platform.startswith('darwin'):
        # There is no particular minimum version we need for clang, so we're good here.
        return True
    try:
        if sys.platform.startswith('linux'):
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            version = versionstr.decode().strip().split('.')
        else:
            minimum_required_version = MINIMUM_MSVC_VERSION
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
            # Components stay strings on both paths: the warning below
            # '.'.join()s them, which raised TypeError for an int fallback.
            version = ['0', '0', '0'] if match is None else list(match.groups())
    except Exception as error:
        warnings.warn(f'Error checking compiler version for {compiler}: {error}')
        return False

    if tuple(map(int, version)) >= minimum_required_version:
        return True

    compiler = f'{compiler} {".".join(version)}'
    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))
    return False
# See below for why we inherit BuildExtension from object.
# https://stackoverflow.com/questions/1713038/super-fails-with-error-typeerror-argument-1-must-be-type-not-classobj-when
class BuildExtension(build_ext, object):
r'''
A custom :mod:`setuptools` build extension .
This :class:`setuptools.build_ext` subclass takes care of passing the
minimum required compiler flags (e.g. ``-std=c++14``) as well as mixed
C++/CUDA compilation (and support for CUDA files in general).
When using :class:`BuildExtension`, it is allowed to supply a dictionary
for ``extra_compile_args`` (rather than the usual list) that maps from
languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
supply to the compiler. This makes it possible to supply different flags to
the C++ and CUDA compiler during mixed compilation.
``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
attempt to build using the Ninja backend. Ninja greatly speeds up
compilation compared to the standard ``setuptools.build_ext``.
Fallbacks to the standard distutils backend if Ninja is not available.
.. note::
By default, the Ninja backend uses #CPUS + 2 workers to build the
extension. This may use up too many resources on some systems. One
can control the number of workers by setting the `MAX_JOBS` environment
variable to a non-negative number.
'''
Loading ...