Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
pymongo / test / test_pooling.py
Size: Mime:
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test built in connection-pooling with threads."""

import gc
import random
import socket
import sys
import threading
import time

from bson.son import SON
from bson.codec_options import DEFAULT_CODEC_OPTIONS

from pymongo import MongoClient, message
from pymongo.errors import (AutoReconnect,
                            ConnectionFailure,
                            DuplicateKeyError,
                            ExceededMaxWaiters)

sys.path[0:0] = [""]

from pymongo.pool import Pool, PoolOptions
from pymongo.socket_checker import SocketChecker
from test import client_context, IntegrationTest, unittest
from test.utils import (get_pool,
                        joinall,
                        delay,
                        rs_or_single_client)


@client_context.require_connection
def setUpModule():
    """Module-level fixture with no body of its own.

    All behaviour comes from the ``client_context.require_connection``
    decorator (presumably it skips this module when no server is
    reachable -- confirm against test/__init__.py).
    """

# Loop count used by the MongoThread workers below (each does N operations).
N = 10
# Database that the pooling tests create, fill, and query.
DB = "pymongo-pooling-tests"


def gc_collect_until_done(threads, timeout=60):
    """Wait for every thread in ``threads`` to finish, collecting garbage.

    Each still-running thread is joined with a 0.1 second timeout, and
    ``gc.collect()`` runs after every polling pass while any thread is
    still alive.

    :Parameters:
      - `threads`: iterable of ``threading.Thread`` objects.
      - `timeout` (optional): seconds to wait for all threads to finish.

    :Raises:
      AssertionError("Threads timed out") if any thread is still alive once
      ``timeout`` seconds have elapsed.
    """
    start = time.time()
    running = list(threads)
    while running:
        # Raise explicitly rather than via ``assert`` so the timeout is still
        # enforced when Python runs with -O.
        if time.time() - start >= timeout:
            raise AssertionError("Threads timed out")
        for t in running:
            t.join(0.1)
        # Rebuild the list instead of calling remove() while iterating it,
        # which silently skips the element after each removed thread.
        running = [t for t in running if t.is_alive()]
        gc.collect()


class MongoThread(threading.Thread):
    """Base class for a daemon worker thread that uses a MongoClient.

    Subclasses implement ``run_mongo_thread``. ``passed`` flips to True only
    when ``run_mongo_thread`` returns without raising, which lets the test
    that spawned the thread detect failures raised inside it.
    """

    def __init__(self, client):
        super(MongoThread, self).__init__()
        # Daemonize so a hung worker cannot hang the whole test run.
        self.daemon = True
        self.client = client
        self.db = client[DB]
        self.passed = False

    def run(self):
        self.run_mongo_thread()
        # Only reached when run_mongo_thread() did not raise.
        self.passed = True

    def run_mongo_thread(self):
        """Perform this thread's work; subclasses must override it."""
        raise NotImplementedError


class InsertOneAndFind(MongoThread):
    """Insert N random-valued documents, reading each one back by _id."""

    def run_mongo_thread(self):
        coll = self.db.sf
        for _ in range(N):
            expected = random.randint(0, N)
            result = coll.insert_one({"x": expected})
            found = coll.find_one(result.inserted_id)
            assert expected == found["x"]


class Unique(MongoThread):
    """Insert N empty documents into ``unique``; none should raise."""

    def run_mongo_thread(self):
        coll = self.db.unique
        for _ in range(N):
            coll.insert_one({})  # Expected to succeed.


class NonUnique(MongoThread):
    """Insert ``{"_id": "jesse"}`` N times; every insert must be rejected."""

    def run_mongo_thread(self):
        coll = self.db.unique
        for _ in range(N):
            try:
                coll.insert_one({"_id": "jesse"})
            except DuplicateKeyError:
                # Expected: _TestPoolingBase.setUp inserted this _id already.
                continue
            raise AssertionError("Should have raised DuplicateKeyError")


class Disconnect(MongoThread):
    """Call ``close()`` on the shared client N times.

    run_cases runs this alongside other MongoThread workers on the same
    client, so those workers must tolerate the client being closed.
    """

    def run_mongo_thread(self):
        close = self.client.close
        for _ in range(N):
            close()


class SocketGetter(MongoThread):
    """Utility for TestPooling.

    Checks out a socket from ``pool``, pins it, and keeps a reference in
    ``self.sock``. Used in test_no_wait_queue_timeout,
    test_wait_queue_multiple, and test_no_wait_queue_multiple.

    ``state`` lets tests observe progress: 'init' after construction,
    'get_socket' while waiting on the pool, and 'sock' once the checkout
    block has completed.
    """

    def __init__(self, client, pool):
        super(SocketGetter, self).__init__(client)
        self.pool = pool
        self.sock = None
        self.state = 'init'

    def run_mongo_thread(self):
        self.state = 'get_socket'

        with self.pool.get_socket({}) as sock_info:
            # Call 'pin_cursor' so we can hold the socket.
            sock_info.pin_cursor()
            self.sock = sock_info

        self.state = 'sock'

    def __del__(self):
        held = self.sock
        if held:
            held.close_socket(None)


def run_cases(client, cases, n_runs=5):
    """Run each MongoThread subclass in ``cases`` concurrently on ``client``.

    Starts ``n_runs`` threads per case (all sharing ``client``), joins every
    thread, then fails if any thread did not set ``passed``.

    :Parameters:
      - `client`: object passed to each case's constructor.
      - `cases`: iterable of MongoThread subclasses (or compatible
        callables returning objects with start/join/passed).
      - `n_runs` (optional): threads started per case; defaults to 5.

    :Raises:
      AssertionError listing every thread whose run() did not complete.
    """
    threads = []
    for case in cases:
        for _ in range(n_runs):
            t = case(client)
            t.start()
            threads.append(t)

    for t in threads:
        t.join()

    failed = [t for t in threads if not t.passed]
    if failed:
        # Raise explicitly (not via ``assert``) so the check survives -O,
        # and report every failing thread rather than only the first.
        raise AssertionError("; ".join(
            "%s.run() threw an exception" % repr(t) for t in failed))


class _TestPoolingBase(IntegrationTest):
    """Base class for all connection-pool tests.

    setUp creates ``self.c`` (closed in tearDown) and resets two collections
    in DB: ``unique`` holds one document with ``_id: "jesse"`` and ``test``
    holds ten empty documents.
    """

    def setUp(self):
        super(_TestPoolingBase, self).setUp()
        self.c = rs_or_single_client()
        db = self.c[DB]
        db.unique.drop()
        db.test.drop()
        db.unique.insert_one({"_id": "jesse"})
        db.test.insert_many([{} for _ in range(10)])

    def tearDown(self):
        self.c.close()
        super(_TestPoolingBase, self).tearDown()

    def create_pool(self, pair=None, *args, **kwargs):
        """Return a new Pool configured like the shared test client.

        :Parameters:
          - `pair` (optional): (host, port) address of the server. Defaults
            to ``(client_context.host, client_context.port)``, looked up
            when this method is called.
          - `*args`, `**kwargs`: forwarded to ``PoolOptions``. Any
            ``ssl_context``, ``ssl_match_hostname`` or ``server_api`` given
            here is overwritten with the values from client_context.client.
        """
        if pair is None:
            # Resolve at call time: a default-argument expression is
            # evaluated only once, when the class body runs at import, and
            # would miss any later change to client_context's host/port.
            pair = (client_context.host, client_context.port)
        # Start the pool with the correct ssl options.
        pool_options = client_context.client._topology_settings.pool_options
        kwargs['ssl_context'] = pool_options.ssl_context
        kwargs['ssl_match_hostname'] = pool_options.ssl_match_hostname
        kwargs['server_api'] = pool_options.server_api
        return Pool(pair, PoolOptions(*args, **kwargs))


class TestPooling(_TestPoolingBase):
    """Tests for Pool checkout/return, socket health checks and wait queues.

    Every Pool and raw socket a test creates is released through addCleanup,
    so it is closed even when an assertion fails partway through the test.
    """

    def _wait_for_state(self, thread, state, timeout=30):
        """Poll ``thread.state`` every 0.1s until it equals ``state``.

        Fails the test after ``timeout`` seconds instead of spinning forever
        when the thread died or never reaches ``state``.
        """
        deadline = time.time() + timeout
        while thread.state != state:
            if time.time() > deadline:
                self.fail("%r did not reach state %r within %s seconds "
                          "(current state: %r)"
                          % (thread, state, timeout, thread.state))
            time.sleep(0.1)

    def test_max_pool_size_validation(self):
        host, port = client_context.host, client_context.port
        self.assertRaises(
            ValueError, MongoClient, host=host, port=port, maxPoolSize=-1)

        self.assertRaises(
            ValueError, MongoClient, host=host, port=port, maxPoolSize='foo')

        c = MongoClient(host=host, port=port, maxPoolSize=100, connect=False)
        self.assertEqual(c.max_pool_size, 100)

    def test_no_disconnect(self):
        run_cases(self.c, [NonUnique, Unique, InsertOneAndFind])

    def test_disconnect(self):
        run_cases(self.c, [InsertOneAndFind, Disconnect, Unique])

    def test_pool_reuses_open_socket(self):
        # Test Pool's _check_closed() method doesn't close a healthy socket.
        cx_pool = self.create_pool(max_pool_size=10)
        self.addCleanup(cx_pool.close)
        cx_pool._check_interval_seconds = 0  # Always check.
        with cx_pool.get_socket({}) as sock_info:
            pass

        with cx_pool.get_socket({}) as new_sock_info:
            self.assertEqual(sock_info, new_sock_info)

        self.assertEqual(1, len(cx_pool.sockets))

    def test_get_socket_and_exception(self):
        # get_socket() returns socket after a non-network error.
        cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
        self.addCleanup(cx_pool.close)
        with self.assertRaises(ZeroDivisionError):
            with cx_pool.get_socket({}) as sock_info:
                1 / 0

        # Socket was returned, not closed.
        with cx_pool.get_socket({}) as new_sock_info:
            self.assertEqual(sock_info, new_sock_info)

        self.assertEqual(1, len(cx_pool.sockets))

    def test_pool_removes_closed_socket(self):
        # Test that Pool removes explicitly closed socket.
        cx_pool = self.create_pool()
        self.addCleanup(cx_pool.close)

        with cx_pool.get_socket({}) as sock_info:
            # Use SocketInfo's API to close the socket.
            sock_info.close_socket(None)

        self.assertEqual(0, len(cx_pool.sockets))

    def test_pool_removes_dead_socket(self):
        # Test that Pool removes dead socket and the socket doesn't return
        # itself PYTHON-344
        cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1)
        self.addCleanup(cx_pool.close)
        cx_pool._check_interval_seconds = 0  # Always check.

        with cx_pool.get_socket({}) as sock_info:
            # Simulate a closed socket without telling the SocketInfo it's
            # closed.
            sock_info.sock.close()
            self.assertTrue(sock_info.socket_closed())

        with cx_pool.get_socket({}) as new_sock_info:
            self.assertEqual(0, len(cx_pool.sockets))
            self.assertNotEqual(sock_info, new_sock_info)

        self.assertEqual(1, len(cx_pool.sockets))

        # Semaphore was released.
        with cx_pool.get_socket({}):
            pass

    def test_socket_closed(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Close the raw socket even if an assertion below fails; closing an
        # already-closed Python socket is harmless.
        self.addCleanup(s.close)
        s.connect((client_context.host, client_context.port))
        socket_checker = SocketChecker()
        self.assertFalse(socket_checker.socket_closed(s))
        s.close()
        self.assertTrue(socket_checker.socket_closed(s))

    def test_socket_checker(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Close the raw socket even if an assertion below fails.
        self.addCleanup(s.close)
        s.connect((client_context.host, client_context.port))
        socket_checker = SocketChecker()
        # Socket has nothing to read.
        self.assertFalse(socket_checker.select(s, read=True))
        self.assertFalse(socket_checker.select(s, read=True, timeout=0))
        self.assertFalse(socket_checker.select(s, read=True, timeout=.05))
        # Socket is writable.
        self.assertTrue(socket_checker.select(s, write=True, timeout=None))
        self.assertTrue(socket_checker.select(s, write=True))
        self.assertTrue(socket_checker.select(s, write=True, timeout=0))
        self.assertTrue(socket_checker.select(s, write=True, timeout=.05))
        # Make the socket readable
        _, msg, _ = message.query(
            0, 'admin.$cmd', 0, -1, SON([('ping', 1)]), None,
            DEFAULT_CODEC_OPTIONS)
        s.sendall(msg)
        # Block until the socket is readable.
        self.assertTrue(socket_checker.select(s, read=True, timeout=None))
        self.assertTrue(socket_checker.select(s, read=True))
        self.assertTrue(socket_checker.select(s, read=True, timeout=0))
        self.assertTrue(socket_checker.select(s, read=True, timeout=.05))
        # Socket is still writable.
        self.assertTrue(socket_checker.select(s, write=True, timeout=None))
        self.assertTrue(socket_checker.select(s, write=True))
        self.assertTrue(socket_checker.select(s, write=True, timeout=0))
        self.assertTrue(socket_checker.select(s, write=True, timeout=.05))
        s.close()
        self.assertTrue(socket_checker.socket_closed(s))

    def test_return_socket_after_reset(self):
        pool = self.create_pool()
        self.addCleanup(pool.close)
        with pool.get_socket({}) as sock:
            pool.reset()

        self.assertTrue(sock.closed)
        self.assertEqual(0, len(pool.sockets))

    def test_pool_check(self):
        # Test that Pool recovers from two connection failures in a row.
        # This exercises code at the end of Pool._check().
        cx_pool = self.create_pool(max_pool_size=1,
                                   connect_timeout=1,
                                   wait_queue_timeout=1)
        cx_pool._check_interval_seconds = 0  # Always check.
        self.addCleanup(cx_pool.close)

        with cx_pool.get_socket({}) as sock_info:
            # Simulate a closed socket without telling the SocketInfo it's
            # closed.
            sock_info.sock.close()

        # Swap pool's address with a bad one.
        address, cx_pool.address = cx_pool.address, ('foo.com', 1234)
        with self.assertRaises(AutoReconnect):
            with cx_pool.get_socket({}):
                pass

        # Back to normal, semaphore was correctly released.
        cx_pool.address = address
        with cx_pool.get_socket({}):
            pass

    def test_wait_queue_timeout(self):
        wait_queue_timeout = 2  # Seconds
        pool = self.create_pool(
            max_pool_size=1, wait_queue_timeout=wait_queue_timeout)
        self.addCleanup(pool.close)

        # Hold the only socket so the second checkout has to wait.
        with pool.get_socket({}):
            start = time.time()
            with self.assertRaises(ConnectionFailure):
                with pool.get_socket({}):
                    pass

        duration = time.time() - start
        self.assertLess(
            abs(wait_queue_timeout - duration), 1,
            "Waited %.2f seconds for a socket, expected %f" % (
                duration, wait_queue_timeout))

    def test_no_wait_queue_timeout(self):
        # Verify get_socket() with no wait_queue_timeout blocks forever.
        pool = self.create_pool(max_pool_size=1)
        self.addCleanup(pool.close)

        # Reach max_size.
        with pool.get_socket({}) as s1:
            t = SocketGetter(self.c, pool)
            t.start()
            self._wait_for_state(t, 'get_socket')

            time.sleep(1)
            self.assertEqual(t.state, 'get_socket')

        # Releasing s1 must unblock the waiting thread.
        self._wait_for_state(t, 'sock')

        self.assertEqual(t.state, 'sock')
        self.assertEqual(t.sock, s1)

    def test_wait_queue_multiple(self):
        max_pool_size = 2
        wait_queue_multiple = 3
        pool = self.create_pool(
            max_pool_size=max_pool_size,
            wait_queue_multiple=wait_queue_multiple)
        self.addCleanup(pool.close)

        # Reach max_size sockets.
        with pool.get_socket({}):
            with pool.get_socket({}):

                # Reach max_size * wait_queue_multiple waiters.
                threads = []
                for _ in range(max_pool_size * wait_queue_multiple):
                    t = SocketGetter(self.c, pool)
                    t.start()
                    threads.append(t)

                time.sleep(1)
                for t in threads:
                    self.assertEqual(t.state, 'get_socket')

                with self.assertRaises(ExceededMaxWaiters):
                    with pool.get_socket({}):
                        pass

    def test_no_wait_queue_multiple(self):
        pool = self.create_pool(max_pool_size=2)
        self.addCleanup(pool.close)

        for _ in range(2):
            # Call 'pin_cursor' so we can hold the socket.
            with pool.get_socket({}) as sock:
                sock.pin_cursor()
            # Close each held socket after the test, even if it fails.
            self.addCleanup(sock.close_socket, None)

        # Without wait_queue_multiple, far more waiters than sockets are
        # allowed; all of them should simply block.
        threads = []
        for _ in range(30):
            t = SocketGetter(self.c, pool)
            t.start()
            threads.append(t)
        time.sleep(1)
        for t in threads:
            self.assertEqual(t.state, 'get_socket')


class TestPoolMaxSize(_TestPoolingBase):
    """Tests for the maxPoolSize option and the pool's size limit."""

    def _run_slow_finds(self, client, nthreads=10, max_pool_size=None):
        """Drive ``client`` from ``nthreads`` threads doing slow queries.

        Resets ``DB.test`` to a single document, then starts ``nthreads``
        threads that each run five ``find_one`` calls with a
        ``$where: delay(0.1)`` filter. When ``max_pool_size`` is not None,
        each thread also checks after every query that the pool holds at
        most ``max_pool_size`` sockets.

        Asserts that every thread finished and that the pool grew beyond
        one socket. Returns the client's pool for further checks.
        """
        collection = client[DB].test

        # Need one document.
        collection.drop()
        collection.insert_one({})

        cx_pool = get_pool(client)
        threads = []
        lock = threading.Lock()
        self.n_passed = 0

        def f():
            for _ in range(5):
                collection.find_one({'$where': delay(0.1)})
                # Explicit raise (not ``assert``) so the check survives -O.
                # An exception here skips the n_passed increment below,
                # which the assertEqual after joinall() detects.
                if (max_pool_size is not None
                        and len(cx_pool.sockets) > max_pool_size):
                    raise AssertionError(
                        "Pool has %d sockets, max_pool_size is %d" % (
                            len(cx_pool.sockets), max_pool_size))

            with lock:
                self.n_passed += 1

        for _ in range(nthreads):
            t = threading.Thread(target=f)
            threads.append(t)
            t.start()

        joinall(threads)
        self.assertEqual(nthreads, self.n_passed)
        self.assertGreater(len(cx_pool.sockets), 1)
        return cx_pool

    def test_max_pool_size(self):
        max_pool_size = 4
        c = rs_or_single_client(maxPoolSize=max_pool_size)
        self.addCleanup(c.close)

        # nthreads had better be much larger than max_pool_size to ensure that
        # max_pool_size sockets are actually required at some point in this
        # test's execution.
        cx_pool = self._run_slow_finds(
            c, nthreads=10, max_pool_size=max_pool_size)
        self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter)

    def test_max_pool_size_none(self):
        c = rs_or_single_client(maxPoolSize=None)
        self.addCleanup(c.close)
        self._run_slow_finds(c, nthreads=10)

    def test_max_pool_size_zero(self):
        with self.assertRaises(ValueError):
            rs_or_single_client(maxPoolSize=0)

    def test_max_pool_size_with_connection_failure(self):
        # The pool acquires its semaphore before attempting to connect; ensure
        # it releases the semaphore on connection failure.
        test_pool = Pool(
            ('somedomainthatdoesntexist.org', 27017),
            PoolOptions(
                max_pool_size=1,
                connect_timeout=1,
                socket_timeout=1,
                wait_queue_timeout=1))
        self.addCleanup(test_pool.close)

        # First call to get_socket fails; if pool doesn't release its semaphore
        # then the second call raises "ConnectionFailure: Timed out waiting for
        # socket from pool" instead of AutoReconnect.
        for _ in range(2):
            with self.assertRaises(AutoReconnect) as context:
                with test_pool.get_socket({}):
                    pass

            # Testing for AutoReconnect instead of ConnectionFailure, above,
            # is sufficient right *now* to catch a semaphore leak. But that
            # seems error-prone, so check the message too.
            self.assertNotIn('waiting for socket from pool',
                             str(context.exception))


if __name__ == "__main__":
    # Allow running this module directly (e.g. `python test/test_pooling.py`)
    # in addition to discovery by the test runner.
    unittest.main()