Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

agriconnect / uvloop   python

Repository URL to install this package:

Version: 0.9.1 

/ _testbase.py

"""Test utilities. Don't use outside of the uvloop project."""


import asyncio
import asyncio.events
import collections
import contextlib
import gc
import logging
import os
import pprint
import re
import select
import socket
import ssl
import tempfile
import threading
import time
import unittest
import uvloop


class MockPattern(str):
    """A string that compares equal to any text its regex matches.

    The string's own value is used as a regular expression; ``pattern ==
    text`` is true when ``re.search`` finds the pattern anywhere in *text*
    (``re.S`` is set, so ``.`` also matches newlines).
    """

    # Equality is regex-based, so a value-based hash would be inconsistent
    # with it.  Defining ``__eq__`` already makes the class unhashable; this
    # just states it explicitly.
    __hash__ = None

    def __eq__(self, other):
        # Only text can be searched; let Python fall back to its default
        # comparison for anything else instead of raising TypeError.
        if not isinstance(other, str):
            return NotImplemented
        return bool(re.search(str(self), other, re.S))

    def __ne__(self, other):
        # ``str`` defines its own ``__ne__``; without this override ``!=``
        # would compare raw string values and disagree with ``==``.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class TestCaseDict(collections.UserDict):
    """Mapping that refuses to overwrite an existing key.

    *name* is stored on the instance and used only in the error message
    raised when a key is assigned a second time.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name

    def __setitem__(self, key, value):
        # First assignment of a key is the normal path.
        if key not in self.data:
            super().__setitem__(key, value)
            return

        message = 'duplicate test {}.{}'.format(self.name, key)
        raise RuntimeError(message)


class BaseTestCaseMeta(type):
    """Metaclass that rejects duplicated ``test_*`` names.

    The class body is collected in a ``TestCaseDict`` (see ``__prepare__``)
    and ``__new__`` additionally refuses any ``test_*`` name that one of
    the direct bases already exposes as an attribute.
    """

    @classmethod
    def __prepare__(mcls, name, bases):
        return TestCaseDict(name)

    def __new__(mcls, name, bases, dct):
        new_tests = (attr for attr in dct if attr.startswith('test_'))

        for test_name in new_tests:
            # First base (in declaration order) that already has this name.
            clash = next(
                (base for base in bases if hasattr(base, test_name)), None)
            if clash is not None:
                raise RuntimeError(
                    'duplicate test {}.{} (also defined in {} '
                    'parent class)'.format(
                        name, test_name, clash.__name__))

        # Hand ``type`` a plain dict rather than the collecting mapping.
        namespace = dict(dct)
        return super().__new__(mcls, name, bases, namespace)


class BaseTestCase(unittest.TestCase, metaclass=BaseTestCaseMeta):
    """Common base class for event-loop test cases.

    Subclasses supply the loop under test through ``new_loop()``.
    ``setUp`` creates that loop and installs an exception handler that
    records every context passed to ``call_exception_handler()``;
    ``tearDown`` closes the loop, fails the test if anything was recorded
    and, for loops exposing a truthy ``_debug_cc`` attribute, asserts that
    the loop's debug counters balance.
    """

    def new_loop(self):
        """Return a fresh event loop.  Must be overridden by subclasses."""
        raise NotImplementedError

    def mock_pattern(self, str):
        """Return a ``MockPattern`` built from the regular expression *str*.

        The parameter shadows the ``str`` builtin; the name is kept so that
        existing keyword callers keep working.
        """
        return MockPattern(str)

    def is_asyncio_loop(self):
        """Return True if the loop's type lives in an ``asyncio.*`` module."""
        return type(self.loop).__module__.startswith('asyncio.')

    def run_loop_briefly(self, *, delay=0.01):
        """Run the loop for *delay* seconds (keyword-only)."""
        # The loop is passed explicitly because setUp() clears the default
        # event loop with ``asyncio.set_event_loop(None)``.
        self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop))

    def loop_exception_handler(self, loop, context):
        """Record *context*, then defer to the loop's default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        """Create the loop and install the test-time hooks."""
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)
        self._check_unclosed_resources_in_debug = True

        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

        if hasattr(asyncio, '_get_running_loop'):
            # Disable `_get_running_loop`.
            self._get_running_loop = asyncio.events._get_running_loop
            asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        """Close the loop, undo setUp()'s patching and run leak checks."""
        try:
            self.loop.close()
        finally:
            # Restore the real `_get_running_loop` unconditionally.  Doing
            # this only after the checks below would leave the stub in
            # place whenever a check fails, and the next setUp() would then
            # save the stub as the "original".
            if hasattr(asyncio, '_get_running_loop'):
                asyncio.events._get_running_loop = self._get_running_loop

        if self.__unhandled_exceptions:
            print('Unexpected calls to loop.call_exception_handler():')
            pprint.pprint(self.__unhandled_exceptions)
            # fail() raises, so nothing below runs for this test.
            self.fail('unexpected calls to loop.call_exception_handler()')

        if not self._check_unclosed_resources_in_debug:
            return

        # GC to show any resource warnings as the test completes
        gc.collect()
        gc.collect()
        gc.collect()

        if getattr(self.loop, '_debug_cc', False):
            gc.collect()
            gc.collect()
            gc.collect()

            self.assertEqual(
                self.loop._debug_uv_handles_total,
                self.loop._debug_uv_handles_freed,
                'not all uv_handle_t handles were freed')

            self.assertEqual(
                self.loop._debug_cb_handles_count, 0,
                'not all callbacks (call_soon) are GCed')

            self.assertEqual(
                self.loop._debug_cb_timer_handles_count, 0,
                'not all timer callbacks (call_later) are GCed')

            self.assertEqual(
                self.loop._debug_stream_write_ctx_cnt, 0,
                'not all stream write contexts are GCed')

            for h_name, h_cnt in self.loop._debug_handles_current.items():
                with self.subTest('Alive handle after test',
                                  handle_name=h_name):
                    self.assertEqual(
                        h_cnt, 0,
                        'alive {} after test'.format(h_name))

            for h_name, h_cnt in self.loop._debug_handles_total.items():
                with self.subTest('Total/closed handles',
                                  handle_name=h_name):
                    self.assertEqual(
                        h_cnt, self.loop._debug_handles_closed[h_name],
                        'total != closed for {}'.format(h_name))

        asyncio.set_event_loop(None)
        self.loop = None

    def skip_unclosed_handles_check(self):
        """Make tearDown() skip the GC pass and the debug-counter checks."""
        self._check_unclosed_resources_in_debug = False

    @staticmethod
    def _validate_timeout(timeout):
        """Raise RuntimeError unless *timeout* is a positive number."""
        if timeout is None:
            raise RuntimeError('timeout is required')
        if timeout <= 0:
            raise RuntimeError('only blocking sockets are supported')

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Bind a listening socket and wrap it in a ``TestThreadedServer``.

        When *addr* is None, ``('127.0.0.1', 0)`` is used, or a temporary
        file name for ``AF_UNIX``.  ``TestThreadedServer`` is defined
        elsewhere in this module; *server_prog*, *timeout* and
        *max_clients* are passed through to it unchanged.

        Raises RuntimeError if *timeout* is None or not positive, and
        re-raises any OSError from ``bind()``/``listen()`` after closing
        the socket.
        """
        # Validate before creating the socket so that a bad timeout does
        # not leak an open socket.
        self._validate_timeout(timeout)

        if addr is None:
            if family == socket.AF_UNIX:
                # Only the generated name is wanted: the file itself is
                # removed when the ``with`` block exits, before bind().
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        try:
            sock.bind(addr)
            sock.listen(backlog)
        except OSError:
            sock.close()
            raise

        return TestThreadedServer(
            self, sock, server_prog, timeout, max_clients)

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Create a stream socket and wrap it in a ``TestThreadedClient``.

        ``TestThreadedClient`` is defined elsewhere in this module;
        *client_prog* and *timeout* are passed through to it unchanged.

        Raises RuntimeError if *timeout* is None or not positive.
        """
        # Validate before creating the socket so that a bad timeout does
        # not leak an open socket.
        self._validate_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)

        return TestThreadedClient(
            self, sock, client_prog, timeout)

    def unix_server(self, *args, **kwargs):
        """Same as ``tcp_server()`` with ``family=socket.AF_UNIX``."""
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """Same as ``tcp_client()`` with ``family=socket.AF_UNIX``."""
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a path named ``sock`` inside a fresh temporary directory.

        On exit the path is unlinked (errors ignored, since the caller may
        never have created it) and the directory is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                try:
                    os.unlink(fn)
                except OSError:
                    pass

    def _abort_socket_test(self, ex):
        """Stop the loop and fail the test with *ex* as the message."""
        try:
            self.loop.stop()
        finally:
            # Report the failure even if stopping the loop itself raised.
            self.fail(ex)


def _cert_fullname(test_file_name, cert_file_name):
    """Return the absolute path of a certificate next to a test file.

    The certificate is looked up as ``certs/<cert_file_name>`` inside the
    directory that contains *test_file_name*.

    Raises AssertionError if no regular file exists at that path.
    """
    certs_dir = os.path.join(os.path.dirname(test_file_name), 'certs')
    fullname = os.path.abspath(os.path.join(certs_dir, cert_file_name))
    # Raised explicitly (rather than via ``assert``) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not os.path.isfile(fullname):
        raise AssertionError(
            'certificate file not found: {}'.format(fullname))
    return fullname


@contextlib.contextmanager
def silence_long_exec_warning():
    """Context manager hiding "Executing ... seconds" asyncio log records.

    While the block runs, a filter on the ``asyncio`` logger drops every
    record whose message template starts with ``Executing`` and ends with
    ``seconds``.  The filter is removed on exit, even if the block raises.
    """

    class Filter(logging.Filter):
        def filter(self, record):
            msg = record.msg
            # ``msg`` may be any object (e.g. an exception instance); only
            # strings can be one of the targeted warnings.
            if not isinstance(msg, str):
                return True
            return not (msg.startswith('Executing') and
                        msg.endswith('seconds'))

    logger = logging.getLogger('asyncio')
    log_filter = Filter()
    logger.addFilter(log_filter)
    try:
        yield
    finally:
        logger.removeFilter(log_filter)


def find_free_port(start_from=50000):
    """Return the first port in ``[start_from, start_from + 500)`` that binds.

    Each candidate is probed by binding a throw-away socket to ``('', port)``;
    the probe socket is closed before returning.

    Raises RuntimeError when none of the 500 candidates could be bound.
    """
    candidate = start_from
    limit = start_from + 500

    while candidate < limit:
        with socket.socket() as probe:
            try:
                probe.bind(('', candidate))
            except socket.error:
                # Port is taken (or otherwise unusable); try the next one.
                pass
            else:
                return candidate
        candidate += 1

    raise RuntimeError('could not find a free port')


class SSLTestCase:
    """Mixin with SSL context factories and a log-silencing helper."""

    def _create_server_ssl_context(self, certfile, keyfile=None):
        """Return a server-side SSLContext loaded with *certfile*/*keyfile*.

        SSLv2 is disabled via ``OP_NO_SSLv2``.
        """
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        sslcontext.options |= ssl.OP_NO_SSLv2
        sslcontext.load_cert_chain(certfile, keyfile)
        return sslcontext

    def _create_client_ssl_context(self):
        """Return a client SSLContext with certificate checks turned off."""
        sslcontext = ssl.create_default_context()
        # check_hostname must be cleared before verify_mode can be set to
        # CERT_NONE.
        sslcontext.check_hostname = False
        sslcontext.verify_mode = ssl.CERT_NONE
        return sslcontext

    @contextlib.contextmanager
    def _silence_eof_received_warning(self):
        """Hide the "has no effect when using ssl" asyncio log message.

        Only records whose message contains that phrase are dropped; other
        records on the ``asyncio`` logger pass through.  The filter is
        removed on exit, even if the block raises.
        """
        # TODO This warning has to be fixed in asyncio.

        class Filter(logging.Filter):
            # ``logging.Filter(name)`` matches on the *logger name*, not on
            # the message text, so passing the phrase as ``name`` would
            # drop every record on the 'asyncio' logger.  Match the
            # message explicitly instead.
            def filter(self, record):
                return 'has no effect when using ssl' not in str(record.msg)

        logger = logging.getLogger('asyncio')
        log_filter = Filter()
        logger.addFilter(log_filter)
        try:
            yield
        finally:
            logger.removeFilter(log_filter)


class UVTestCase(BaseTestCase):
    """Test case whose loop is created by ``uvloop.new_event_loop()``."""

    implementation = 'uvloop'

    def new_loop(self):
        """Return a new uvloop event loop."""
        loop = uvloop.new_event_loop()
        return loop


class AIOTestCase(BaseTestCase):
    """Test case whose loop is created by ``asyncio.new_event_loop()``.

    A ``SafeChildWatcher`` attached to the test loop is installed for the
    duration of each test and cleared again before the base tearDown runs.
    """

    implementation = 'asyncio'

    def new_loop(self):
        """Return a new asyncio event loop."""
        loop = asyncio.new_event_loop()
        return loop

    def setUp(self):
        super().setUp()

        # The watcher needs the loop created by the base setUp().
        child_watcher = asyncio.SafeChildWatcher()
        child_watcher.attach_loop(self.loop)
        asyncio.set_child_watcher(child_watcher)

    def tearDown(self):
        # Drop the watcher first, then let the base class close the loop.
        asyncio.set_child_watcher(None)
        super().tearDown()


def has_IPv6():
    """Return True if an IPv6 socket can be bound to ``::1``.

    Returns False when the ``socket`` module reports no IPv6 support, or
    when creating or binding the probe socket fails with OSError.
    """
    if not socket.has_ipv6:
        return False

    try:
        # Creating the socket is inside the ``try`` as well: on hosts
        # without IPv6 the constructor itself can raise OSError.
        with socket.socket(socket.AF_INET6) as server_sock:
            server_sock.bind(('::1', 0))
    except OSError:
        return False
    return True


# Run the probe once at import time.  From here on ``has_IPv6`` is the
# boolean result; the function itself is no longer reachable by this name.
has_IPv6 = has_IPv6()


###############################################################################
# Socket Testing Utilities
###############################################################################


class TestSocketWrapper:

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        buf = b''
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if data == b'':
                raise ConnectionAbortedError
Loading ...