Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ testing / _internal / common_utils.py

r"""Importing this file must **not** initialize CUDA context. test_distributed
relies on this assumption to properly run. This means that when this is imported
no CUDA calls shall be made, including torch.cuda.device_count(), etc.

torch.testing._internal.common_cuda.py can freely initialize CUDA context when imported.
"""

import sys
import os
import platform
import re
import gc
import types
import math
from functools import partial
import inspect
import io
import copy
import operator
import argparse
import unittest
import warnings
import random
import contextlib
import shutil
import socket
import subprocess
import time
from collections import OrderedDict
from collections.abc import Sequence
from contextlib import contextmanager
from functools import wraps
from itertools import product
from copy import deepcopy
from numbers import Number
import tempfile
import json
from urllib.request import urlopen
import __main__  # type: ignore[import]
import errno
from typing import cast, Any, Dict, Iterable, Iterator, Optional

from torch.testing._internal import expecttest
from torch.testing import \
    (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type,
     floating_types_and, integral_types, complex_types)

import torch
import torch.cuda
from torch._utils_internal import get_writable_path
from torch._six import string_classes
import torch.backends.cudnn
import torch.backends.mkl
from enum import Enum

# NOTE(review): presumably locks the global torch.backends flags for the test
# run -- confirm against the torch.backends documentation.
torch.backends.disable_global_flags()

# URL prefix for local files; 'win32' uses the three-slash form.
FILE_SCHEMA = "file:///" if sys.platform == 'win32' else "file://"

# Environment-derived switches describing where the tests are running.
IS_SANDCASTLE = (
    os.environ.get('SANDCASTLE') == '1'
    or os.environ.get('TW_JOB_USER') == 'sandcastle'
)
IS_FBCODE = os.environ.get('PYTORCH_TEST_FBCODE') == '1'
IS_REMOTE_GPU = os.environ.get('PYTORCH_TEST_REMOTE_GPU') == '1'

class ProfilingMode(Enum):
    """JIT graph-executor configurations distinguished by the test suite.

    ``cppProfilingFlagsToProfilingMode`` maps the two JIT flags (profiling
    executor, profiling mode) onto these members.
    """
    # Profiling executor disabled.
    LEGACY = 1
    # Profiling executor enabled, profiling mode disabled.
    SIMPLE = 2
    # Profiling executor and profiling mode both enabled.
    PROFILING = 3

def cppProfilingFlagsToProfilingMode():
    """Translate the current JIT profiling flags into a ``ProfilingMode``.

    The value returned by each setter is treated as the previous setting, so
    both flags are briefly set to ``True`` and then put back to what they were.

    Returns:
        ``ProfilingMode.LEGACY`` when the profiling executor was off,
        ``ProfilingMode.PROFILING`` when both flags were on, and
        ``ProfilingMode.SIMPLE`` when only the executor was on.
    """
    executor_was_on = torch._C._jit_set_profiling_executor(True)
    mode_was_on = torch._C._jit_set_profiling_mode(True)
    # Undo the temporary writes made purely to read the flags.
    torch._C._jit_set_profiling_executor(executor_was_on)
    torch._C._jit_set_profiling_mode(mode_was_on)

    if not executor_was_on:
        return ProfilingMode.LEGACY
    return ProfilingMode.PROFILING if mode_was_on else ProfilingMode.SIMPLE

@contextmanager
def enable_profiling_mode_for_profiling_tests():
    """Turn on JIT profiling for the ``with`` body, but only in PROFILING runs.

    When ``GRAPH_EXECUTOR`` is ``ProfilingMode.PROFILING`` both JIT flags are
    set to ``True`` on entry and put back to their previous values on exit.
    For any other executor the context manager changes nothing.
    """
    jit = torch._C
    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
        saved_executor = jit._jit_set_profiling_executor(True)
        saved_mode = jit._jit_set_profiling_mode(True)
    try:
        yield
    finally:
        # The executor is checked again here, exactly as on entry.
        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            jit._jit_set_profiling_executor(saved_executor)
            jit._jit_set_profiling_mode(saved_mode)

@contextmanager
def enable_profiling_mode():
    """Enable the JIT profiling executor and profiling mode for the ``with`` body.

    Both flags are set to ``True`` on entry; on exit (including when the body
    raises) each is reset to the value its setter returned on entry.
    """
    jit = torch._C
    saved_executor = jit._jit_set_profiling_executor(True)
    saved_mode = jit._jit_set_profiling_mode(True)
    try:
        yield
    finally:
        jit._jit_set_profiling_executor(saved_executor)
        jit._jit_set_profiling_mode(saved_mode)

@contextmanager
def num_profiled_runs(num_runs):
    """Temporarily set the JIT's number of profiled runs to ``num_runs``.

    On exit (including when the body raises) the value returned by the setter
    on entry is written back.
    """
    set_runs = torch._C._jit_set_num_profiled_runs
    previous_runs = set_runs(num_runs)
    try:
        yield
    finally:
        set_runs(previous_runs)

# Keep references to the original ``__call__`` implementations; they are
# replaced further down by prof_func_call / prof_meth_call, which delegate
# back to these.
func_call = torch._C.ScriptFunction.__call__
meth_call = torch._C.ScriptMethod.__call__

def prof_callable(callable, *args, **kwargs):
    """Invoke ``callable``, optionally running it twice under profiling.

    If the keyword ``profile_and_replay`` is present (its value is ignored) it
    is stripped before forwarding. When it was present and ``GRAPH_EXECUTOR``
    is ``ProfilingMode.PROFILING``, ``callable`` is invoked twice inside
    ``enable_profiling_mode_for_profiling_tests`` and the second result is
    returned. In every other case ``callable`` is invoked exactly once.
    """
    replay_requested = 'profile_and_replay' in kwargs
    if replay_requested:
        kwargs.pop('profile_and_replay')

    if replay_requested and GRAPH_EXECUTOR == ProfilingMode.PROFILING:
        with enable_profiling_mode_for_profiling_tests():
            # First call's result is discarded; only the second is returned.
            callable(*args, **kwargs)
            return callable(*args, **kwargs)

    return callable(*args, **kwargs)

def prof_func_call(*args, **kwargs):
    """Route a call through ``prof_callable`` using the saved ``func_call``."""
    saved_call = func_call
    return prof_callable(saved_call, *args, **kwargs)

def prof_meth_call(*args, **kwargs):
    """Route a call through ``prof_callable`` using the saved ``meth_call``."""
    saved_call = meth_call
    return prof_callable(saved_call, *args, **kwargs)

# Install the profiling-aware wrappers as the ``__call__`` of ScriptFunction
# and ScriptMethod, so every call goes through prof_callable.
# TODO fix when https://github.com/python/mypy/issues/2427 is addressed
torch._C.ScriptFunction.__call__ = prof_func_call  # type: ignore[assignment]
torch._C.ScriptMethod.__call__ = prof_meth_call  # type: ignore[assignment]

def _get_test_report_path():
    # allow users to override the test file location. We need this
    # because the distributed tests run the same test file multiple
    # times with different configurations.
    override = os.environ.get('TEST_REPORT_SOURCE_OVERRIDE')
    test_source = override if override is not None else 'python-unittest'
    return os.path.join('test-reports', test_source)


# Command-line options understood by the test harness itself. Built-in help is
# disabled here; arguments this parser does not recognise are collected
# separately by parse_known_args() below.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--subprocess', action='store_true',
                    help='whether to run each test in a subprocess')
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
parser.add_argument('--jit_executor', type=str)
parser.add_argument('--repeat', type=int, default=1)
parser.add_argument('--test_bailouts', action='store_true')
# --save-xml takes an optional directory: a bare flag uses the default report
# path; when the flag is absent, the report path is still used if the IN_CI
# environment variable is non-empty, otherwise XML output is off (None).
parser.add_argument('--save-xml', nargs='?', type=str,
                    const=_get_test_report_path(),
                    default=_get_test_report_path() if bool(os.environ.get('IN_CI')) else None)
parser.add_argument('--discover-tests', action='store_true')
parser.add_argument('--log-suffix', type=str, default="")
parser.add_argument('--run-parallel', type=int, default=1)

# Parse the harness options; everything unrecognised lands in `remaining`.
args, remaining = parser.parse_known_args()

# An explicit --jit_executor value picks the mode directly; any other value
# (including the flag being absent) falls back to the current JIT flags.
GRAPH_EXECUTOR = (
    ProfilingMode.LEGACY if args.jit_executor == 'legacy'
    else ProfilingMode.PROFILING if args.jit_executor == 'profiling'
    else ProfilingMode.SIMPLE if args.jit_executor == 'simple'
    # infer flags based on the default settings
    else cppProfilingFlagsToProfilingMode()
)


# Publish the parsed harness options as module-level constants.
LOG_SUFFIX = args.log_suffix
RUN_PARALLEL = args.run_parallel
TEST_BAILOUTS = args.test_bailouts
TEST_DISCOVER = args.discover_tests
TEST_IN_SUBPROCESS = args.subprocess
TEST_SAVE_XML = args.save_xml
REPEAT_COUNT = args.repeat
SEED = args.seed
# --accept is applied only when expecttest.ACCEPT is not already truthy.
if not expecttest.ACCEPT:
    expecttest.ACCEPT = args.accept
# Program name plus the arguments the harness parser did not recognise;
# run_tests() uses this list as its default argv.
UNITTEST_ARGS = [sys.argv[0]] + remaining
# Seed torch's RNG at import time with --seed.
torch.manual_seed(SEED)

def wait_for_process(p):
    """Wait for the child process ``p`` to exit and return its exit status.

    Args:
        p: a ``subprocess.Popen``-like object providing ``wait`` and ``kill``.

    Returns:
        The exit status reported by ``p.wait()``.

    Raises:
        KeyboardInterrupt: if the wait was interrupted and the child did not
            exit within the 5 second grace period; the child is killed first.
        BaseException: any other exception raised while waiting is re-raised
            after the child has been killed.
    """
    try:
        return p.wait()
    except KeyboardInterrupt:
        # Give `p` a chance to handle KeyboardInterrupt. Without this,
        # `pytest` can't print errors it collected so far upon KeyboardInterrupt.
        try:
            exit_status = p.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # `wait(timeout=...)` reports expiry by raising, not by returning
            # None. Without this handler the child was never killed and the
            # `finally` below could block forever.
            exit_status = None
        if exit_status is not None:
            return exit_status
        p.kill()
        # The inner handler has finished, so this re-raises KeyboardInterrupt.
        raise
    except:  # noqa E722, copied from python core library
        p.kill()
        raise
    finally:
        # Always call p.wait() to ensure exit
        p.wait()

def shell(command, cwd=None, env=None):
    """Run ``command`` as a child process and return its exit status.

    Args:
        command: list or tuple of argument tokens; a plain string is rejected
            by an assertion.
        cwd: working directory handed to ``subprocess.Popen``.
        env: environment mapping handed to ``subprocess.Popen``.

    Returns:
        Whatever ``wait_for_process`` returns for the spawned child.
    """
    # Flush our own buffered output before the child is started.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    # The launch-and-wait logic is modelled on CPython's subprocess.call, with
    # extra KeyboardInterrupt (SIGINT) handling and an unconditional final
    # wait; both live in wait_for_process.
    #
    # https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
    assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
    child = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env)
    return wait_for_process(child)


def repeat_test_for_types(dtypes):
    """Decorator factory: run one test method once for every entry in ``dtypes``.

    The decorated method is invoked as ``f(self, *args, dtype=dtype)`` for each
    dtype in turn, each invocation inside its own ``TestCase.subTest`` block.
    """
    def decorator(test_fn):
        @wraps(test_fn)
        def run_for_each_dtype(self, *args):
            for current_dtype in dtypes:
                with TestCase.subTest(self, dtype=current_dtype):
                    test_fn(self, *args, dtype=current_dtype)

        return run_for_each_dtype
    return decorator

# Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
# Any non-empty value counts as true; unset or empty counts as false.
IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI'))


def discover_test_cases_recursively(suite_or_case):
    """Flatten a (possibly nested) test suite into a list of test cases.

    Args:
        suite_or_case: a ``unittest.TestCase`` instance, or an iterable whose
            elements are themselves cases or iterables of the same kind.

    Returns:
        A list of the ``unittest.TestCase`` instances found, in iteration order.
    """
    if isinstance(suite_or_case, unittest.TestCase):
        return [suite_or_case]
    return [
        case
        for member in suite_or_case
        for case in discover_test_cases_recursively(member)
    ]

def get_test_names(test_cases):
    """Return the last two dot-separated components of each case's ``id()``.

    An id with fewer than two components is returned unchanged.
    """
    names = []
    for case in test_cases:
        trailing_parts = case.id().rsplit('.', 2)[-2:]
        names.append('.'.join(trailing_parts))
    return names

def chunk_list(lst, nchunks):
    """Deal ``lst`` round-robin into ``nchunks`` slices.

    Slice ``i`` holds the elements at positions ``i, i + nchunks, ...``.
    Exactly ``nchunks`` slices are returned, some possibly empty.
    """
    chunks = []
    for offset in range(nchunks):
        chunks.append(lst[offset::nchunks])
    return chunks


def run_tests(argv=UNITTEST_ARGS):
    """Entry point for a test file: run the tests defined in ``__main__``.

    Exactly one mode is chosen, in this order of precedence:

    1. ``TEST_DISCOVER``: print the name of every discovered test.
    2. ``TEST_IN_SUBPROCESS``: run each test in its own child interpreter and
       assert that none failed.
    3. ``RUN_PARALLEL > 1``: split the tests into that many shards, run each
       shard in a child interpreter and assert that all succeeded.
    4. ``TEST_SAVE_XML`` set: run via ``xmlrunner`` and write XML reports.
    5. ``REPEAT_COUNT > 1``: run the suite repeatedly, exiting with ``-1`` on
       the first unsuccessful run.
    6. Otherwise: plain ``unittest.main``.

    Args:
        argv: argument list forwarded to ``unittest.main`` or to the child
            interpreters; defaults to ``UNITTEST_ARGS``.
    """
    if TEST_DISCOVER or TEST_IN_SUBPROCESS or RUN_PARALLEL > 1:
        # These three modes all start from the flat list of tests in __main__.
        loaded_suite = unittest.TestLoader().loadTestsFromModule(__main__)
        all_cases = discover_test_cases_recursively(loaded_suite)

        if TEST_DISCOVER:
            for test_name in get_test_names(all_cases):
                print(test_name)
            return

        if TEST_IN_SUBPROCESS:
            failures = []
            for test_case in all_cases:
                # Drop the first dot-separated component of the test id.
                qualified_name = test_case.id().split('.', 1)[1]
                if shell([sys.executable] + argv + [qualified_name]) != 0:
                    failures.append(qualified_name)

            assert len(failures) == 0, "{} unit test(s) failed:\n\t{}".format(
                len(failures), '\n\t'.join(failures))
            return

        # Parallel mode: one child interpreter per shard.
        shards = chunk_list(get_test_names(all_cases), RUN_PARALLEL)
        workers = []
        for shard_number, shard_tests in enumerate(shards, start=1):
            shard_command = [sys.executable] + argv + ['--log-suffix=-shard-{}'.format(shard_number)] + shard_tests
            workers.append(subprocess.Popen(shard_command, universal_newlines=True))
        # Wait for every shard before judging the outcome.
        exit_codes = [wait_for_process(worker) for worker in workers]
        assert all(code == 0 for code in exit_codes), "Some test shards have failed"
        return

    if TEST_SAVE_XML is not None:
        # import here so that non-CI doesn't need xmlrunner installed
        import xmlrunner  # type: ignore[import]
        report_dir = TEST_SAVE_XML + LOG_SUFFIX
        os.makedirs(report_dir, exist_ok=True)
        is_verbose = any(flag in argv for flag in ('--verbose', '-v'))
        if is_verbose:
            print('Test results will be stored in {}'.format(report_dir))
        xml_runner = xmlrunner.XMLTestRunner(output=report_dir, verbosity=2 if is_verbose else 1)
        unittest.main(argv=argv, testRunner=xml_runner)
        return

    if REPEAT_COUNT > 1:
        for _ in range(REPEAT_COUNT):
            program = unittest.main(exit=False, argv=argv)
            if not program.result.wasSuccessful():
                sys.exit(-1)
        return

    unittest.main(argv=argv)

# Platform switches computed once at import time.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
# True when the machine type is reported as little-endian 64-bit PowerPC.
IS_PPC = platform.machine() == "ppc64le"

if IS_WINDOWS:
    @contextmanager
    def TemporaryFileName(*args, **kwargs):
        """Yield the path of a temporary file; the file is removed on exit.

        Arguments are forwarded to ``tempfile.NamedTemporaryFile``. Only
        ``delete=False`` is accepted (and is the default here); any other
        value raises ``UserWarning``.
        """
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually
        if kwargs.setdefault('delete', False) is not False:
            raise UserWarning("only TemporaryFileName with delete=False is supported on Windows.")
        handle = tempfile.NamedTemporaryFile(*args, **kwargs)
        try:
            handle.close()
            yield handle.name
        finally:
            os.unlink(handle.name)
else:
    @contextmanager  # noqa: T484
    def TemporaryFileName(*args, **kwargs):
        """Yield the path of a temporary file that exists for the ``with`` body.

        Arguments are forwarded to ``tempfile.NamedTemporaryFile``, whose own
        context manager handles cleanup.
        """
        handle = tempfile.NamedTemporaryFile(*args, **kwargs)
        with handle:
            yield handle.name

if IS_WINDOWS:
    @contextmanager
    def TemporaryDirectoryName(suffix=None):
        """Yield the path of a fresh temporary directory, removed on exit.

        Args:
            suffix: optional suffix forwarded to ``tempfile.mkdtemp``.
        """
        # On Windows the directory created by TemporaryDirectory is likely to be removed prematurely,
        # so we first create the directory using mkdtemp and then remove it manually
        #
        # Create the directory *before* entering the try block: if mkdtemp
        # fails there is nothing to clean up, and the cleanup must not mask the
        # real error by referencing an unbound `dir_name`.
        dir_name = tempfile.mkdtemp(suffix=suffix)
        try:
            yield dir_name
        finally:
            shutil.rmtree(dir_name)
else:
    @contextmanager  # noqa: T484
    def TemporaryDirectoryName(suffix=None):
        """Yield the path of a fresh temporary directory, removed on exit.

        Args:
            suffix: optional suffix forwarded to ``tempfile.TemporaryDirectory``.
        """
        with tempfile.TemporaryDirectory(suffix=suffix) as d:
            yield d

# True when Python reports 'utf-8' as the filesystem encoding.
IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'

def _check_module_exists(name):
    r"""Returns if a top-level module with :attr:`name` exists *without**
    importing it. This is generally safer than try-catch block around a
Loading ...