import copy
import glob
import imp
import os
import re
import shlex
import setuptools
import subprocess
import sys
import sysconfig
import warnings
import collections

import torch
import torch._appdirs
from .file_baton import FileBaton
from ._cpp_extension_versioner import ExtensionVersioner
from .hipify import hipify_python
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
from typing import List, Optional, Union

from setuptools.command.build_ext import build_ext
from pkg_resources import packaging  # type: ignore

# Platform-dependent naming conventions for build artifacts and linker flags.
IS_WINDOWS = sys.platform == 'win32'
if IS_WINDOWS:
    LIB_EXT, EXEC_EXT, CLIB_PREFIX, CLIB_EXT, SHARED_FLAG = '.pyd', '.exe', '', '.dll', '/DLL'
else:
    LIB_EXT, EXEC_EXT, CLIB_PREFIX, CLIB_EXT, SHARED_FLAG = '.so', '', 'lib', '.so', '-shared'

# _TORCH_PATH is the directory two levels above this file; native libraries
# are looked up in its ``lib`` subdirectory.
_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')

# Truthy when the BUILD_SPLIT_CUDA environment variable is set to a non-empty
# value, otherwise True only if both split CUDA libraries exist in TORCH_LIB_PATH.
BUILD_SPLIT_CUDA = os.getenv('BUILD_SPLIT_CUDA') or all(
    os.path.exists(os.path.join(TORCH_LIB_PATH, f'{CLIB_PREFIX}{lib_name}{CLIB_EXT}'))
    for lib_name in ('torch_cuda_cu', 'torch_cuda_cpp'))

# Taken directly from python stdlib < 3.9
# See https://github.com/pytorch/pytorch/issues/48617
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
    """Quote command-line arguments for DOS/Windows conventions.

    Just wraps every argument which contains a space in double quotes, and
    returns a new argument list.

    Args:
        args: The arguments to quote. ``None`` or an empty list yields ``[]``.

    Returns:
        A new list in which every argument containing a space character is
        wrapped in double quotes; all other arguments are passed through
        unchanged.
    """
    # Cover None-type
    if not args:
        return []
    return [f'"{arg}"' if ' ' in arg else arg for arg in args]

def _find_cuda_home() -> Optional[str]:
    r'''Finds the CUDA install path.

    Tries, in order:

    1. The ``CUDA_HOME`` or ``CUDA_PATH`` environment variable (returned as-is).
    2. Two directory levels above the ``nvcc`` located by ``which``
       (``where`` on Windows).
    3. A platform default: the first ``CUDA/v*.*`` directory under
       ``C:/Program Files/NVIDIA GPU Computing Toolkit`` on Windows,
       ``/usr/local/cuda`` elsewhere. The default is only used if it exists.

    Prints a notice if a path was found but ``torch.cuda.is_available()`` is
    False.

    Returns:
        The CUDA root directory, or ``None`` if none could be found.
    '''
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        try:
            which = 'where' if IS_WINDOWS else 'which'
            output = subprocess.check_output(
                [which, 'nvcc'], stderr=subprocess.DEVNULL).decode()
            # `where` prints every match on its own line; use the first one.
            nvcc = output.splitlines()[0]
            # nvcc is expected at <CUDA_HOME>/bin/nvcc, so strip two components.
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        except Exception:
            # Guess #3
            if IS_WINDOWS:
                cuda_homes = glob.glob(
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                cuda_home = cuda_homes[0] if cuda_homes else ''
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    if cuda_home and not torch.cuda.is_available():
        print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'")
    return cuda_home

def _find_rocm_home() -> Optional[str]:
    r'''Finds the ROCm install path.

    Tries, in order:

    1. The ``ROCM_HOME`` or ``ROCM_PATH`` environment variable (returned as-is).
    2. The install root derived from the ``hipcc`` located by ``which``.
    3. ``/opt/rocm``, if it exists on disk.

    Prints a notice if a path was found but ``torch.version.hip`` is None.

    Returns:
        The ROCm root directory, or ``None`` if none could be found.
    '''
    # Guess #1
    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if rocm_home is None:
        # Guess #2
        try:
            hipcc = subprocess.check_output(
                ['which', 'hipcc'], stderr=subprocess.DEVNULL).decode().rstrip('\r\n')
            # this will be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
            rocm_home = os.path.dirname(os.path.dirname(hipcc))
            if os.path.basename(rocm_home) == 'hip':
                rocm_home = os.path.dirname(rocm_home)
        except Exception:
            # Guess #3
            rocm_home = '/opt/rocm'
            if not os.path.exists(rocm_home):
                rocm_home = None
    if rocm_home and torch.version.hip is None:
        print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
    return rocm_home

def _join_rocm_home(*paths) -> str:
    r'''
    Joins paths with ROCM_HOME, or raises an error if ROCM_HOME is not set.

    This is basically a lazy way of raising an error for missing $ROCM_HOME
    only once we need to get any ROCm-specific path.

    Args:
        *paths: Path components to append to ``ROCM_HOME``.

    Returns:
        ``os.path.join(ROCM_HOME, *paths)``.

    Raises:
        EnvironmentError: If ``ROCM_HOME`` is ``None``, or when running on
            Windows, where building ROCm extensions is not supported.
    '''
    if ROCM_HOME is None:
        raise EnvironmentError('ROCM_HOME environment variable is not set. '
                               'Please set it to your ROCm install root.')
    elif IS_WINDOWS:
        raise EnvironmentError('Building PyTorch extensions using '
                               'ROCm and Windows is not supported.')
    return os.path.join(ROCM_HOME, *paths)

# Minimum compiler versions considered ABI-compatible with binary PyTorch
# builds; compared against the parsed compiler version as int tuples.
MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)

# Warning template for a compiler that may be ABI-incompatible; the single
# ``{}`` placeholder is filled via ``str.format`` with a compiler description.
ABI_INCOMPATIBILITY_WARNING = '''

                               !! WARNING !!

Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.

See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 5 or higher.

                              !! WARNING !!
'''
# Warning template for a compiler that differs from the one PyTorch expects on
# this platform; filled via ``str.format`` with the ``user_compiler``,
# ``pytorch_compiler`` and ``platform`` keyword arguments.
WRONG_COMPILER_WARNING = '''

                               !! WARNING !!

Your compiler ({user_compiler}) is not compatible with the compiler Pytorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.

                              !! WARNING !!
'''
ROCM_HOME = _find_rocm_home()
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
IS_HIP_EXTENSION = (ROCM_HOME is not None) and (torch.version.hip is not None)
# (major, minor) of the HIP version torch was built with; defined
# unconditionally so it is None (not missing) for non-ROCm builds.
ROCM_VERSION = None
if torch.version.hip is not None:
    ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])

CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas when
# PyTorch is built from source, we append the git commit hash, which gives
# it the below pattern (e.g. ``1.8.0a0+abc123``).
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

COMMON_MSVC_FLAGS = ['/MD', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', '/wd4190', '/EHsc']

JIT_EXTENSION_VERSIONER = ExtensionVersioner()

# Maps a Windows platform tag ('win32' / 'win-amd64') to an MSVC vcvars
# architecture name (presumably passed to vcvarsall.bat — confirm at call site).
PLAT_TO_VCVARS = {
    'win32' : 'x86',
    'win-amd64' : 'x86_amd64',
}

def _is_binary_build() -> bool:
    r'''Returns True unless the installed torch version looks like a source build.

    A version string matching ``BUILT_FROM_SOURCE_VERSION_PATTERN`` is treated
    as a source build; anything else counts as a binary build.
    '''
    source_build_match = BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
    return source_build_match is None

def _accepted_compilers_for_platform() -> List[str]:
    r'''Returns the compiler names accepted on the current platform.

    clang++/clang on macOS; g++/gcc plus the conda ``gnu-c++``/``gnu-cc``
    variants on every other platform.
    '''
    if sys.platform.startswith('darwin'):
        return ['clang++', 'clang']
    # gnu-c++ and gnu-cc are the conda gcc compilers
    return ['g++', 'gcc', 'gnu-c++', 'gnu-cc']

def get_default_build_root() -> str:
    r'''
    Returns the path to the root folder under which extensions will be built.

    For each extension module built, there will be one folder underneath the
    folder returned by this function. For example, if ``p`` is the path
    returned by this function and ``ext`` the name of an extension, the build
    folder for the extension will be ``p/ext``.

    This directory is **user-specific** so that multiple users on the same
    machine won't meet permission issues.

    Returns:
        The symlink-resolved (``os.path.realpath``) per-user cache directory
        for the application name ``torch_extensions``.
    '''
    return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions'))

def check_compiler_ok_for_platform(compiler: str) -> bool:
    r'''
    Verifies that the compiler is the expected one for the current platform.

    Args:
        compiler (str): The compiler executable to check.

    Returns:
        True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
        and always True for Windows.

    Raises:
        subprocess.CalledProcessError: If ``which <compiler>`` or
            ``<compiler> -v`` exits with a non-zero status.
    '''
    if IS_WINDOWS:
        return True
    which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
    # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
    compiler_path = os.path.realpath(which.decode().strip())
    # Check the compiler name
    if any(name in compiler_path for name in _accepted_compilers_for_platform()):
        return True
    # If ccache is used the compiler path is /usr/bin/ccache. Check by -v flag.
    version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT).decode()
    if sys.platform.startswith('linux'):
        # Check for 'gcc' or 'g++': gcc -v reports the real driver in COLLECT_GCC.
        pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
        results = re.findall(pattern, version_string)
        if len(results) != 1:
            return False
        compiler_path = os.path.realpath(results[0].strip())
        return any(name in compiler_path for name in _accepted_compilers_for_platform())
    if sys.platform.startswith('darwin'):
        # Check for 'clang' or 'clang++'
        return version_string.startswith("Apple clang")
    return False

def check_compiler_abi_compatibility(compiler) -> bool:
    r'''
    Verifies that the given compiler is ABI-compatible with PyTorch.

    Source builds of PyTorch, and runs with ``TORCH_DONT_CHECK_COMPILER_ABI``
    set to ON/1/YES/TRUE/Y, skip the check entirely. A warning is emitted for
    every negative outcome.

    Args:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.

    Returns:
        False if the compiler is (likely) ABI-incompatible with PyTorch,
        else True.
    '''
    if not _is_binary_build():
        return True
    if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
        return True

    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(WRONG_COMPILER_WARNING.format(
            user_compiler=compiler,
            pytorch_compiler=_accepted_compilers_for_platform()[0],
            platform=sys.platform))
        return False

    if sys.platform.startswith('darwin'):
        # There is no particular minimum version we need for clang, so we're good here.
        return True
    try:
        if sys.platform.startswith('linux'):
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            version = versionstr.decode().strip().split('.')
        else:
            minimum_required_version = MINIMUM_MSVC_VERSION
            # Running cl.exe without arguments prints its banner, which
            # contains the version number.
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
            # Keep components as strings in both branches: the ".".join below
            # would raise TypeError on an int tuple.
            version = ('0', '0', '0') if match is None else match.groups()
    except Exception as error:
        warnings.warn(f'Error checking compiler version for {compiler}: {error}')
        return False

    if tuple(map(int, version)) >= minimum_required_version:
        return True

    compiler = f'{compiler} {".".join(version)}'
    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))

    return False

# See below for why we inherit BuildExtension from object.
# https://stackoverflow.com/questions/1713038/super-fails-with-error-typeerror-argument-1-must-be-type-not-classobj-when

class BuildExtension(build_ext, object):
    A custom :mod:`setuptools` build extension .

    This :class:`setuptools.build_ext` subclass takes care of passing the
    minimum required compiler flags (e.g. ``-std=c++14``) as well as mixed
    C++/CUDA compilation (and support for CUDA files in general).

    When using :class:`BuildExtension`, it is allowed to supply a dictionary
    for ``extra_compile_args`` (rather than the usual list) that maps from
    languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
    supply to the compiler. This makes it possible to supply different flags to
    the C++ and CUDA compiler during mixed compilation.

    ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
    attempt to build using the Ninja backend. Ninja greatly speeds up
    compilation compared to the standard ``setuptools.build_ext``.
    Fallbacks to the standard distutils backend if Ninja is not available.

    .. note::
        By default, the Ninja backend uses #CPUS + 2 workers to build the
        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the `MAX_JOBS` environment
        variable to a non-negative number.
Loading ...