Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ _dynamo / test_case.py

import contextlib
import importlib
import sys

import torch
import torch.testing
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_ROCM,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


def run_tests(needs=()):
    from torch.testing._internal.common_utils import run_tests

    if (
        TEST_WITH_TORCHDYNAMO
        or IS_WINDOWS
        or TEST_WITH_CROSSREF
        or TEST_WITH_ROCM
        or sys.version_info >= (3, 11)
    ):
        return  # skip testing

    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda" and not torch.cuda.is_available():
            return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()


class TestCase(TorchTestCase):
    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()
        cls._exit_stack.enter_context(
            config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False),
        )

    def setUp(self):
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()