# Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages
#
# Repository URL to install this package:
#
# Details
# pymongo / test / utils_spec_runner.py
# Size: Mime:
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for testing driver specs."""

import copy
import functools
import threading


from bson import decode, encode
from bson.binary import Binary, STANDARD
from bson.codec_options import CodecOptions
from bson.int64 import Int64
from bson.py3compat import iteritems, abc, string_type, text_type
from bson.son import SON

from gridfs import GridFSBucket

from pymongo import (client_session,
                     helpers,
                     operations)
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor
from pymongo.errors import (BulkWriteError,
                            OperationFailure,
                            PyMongoError)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import _WriteResult, BulkWriteResult
from pymongo.write_concern import WriteConcern

from test import (client_context,
                  client_knobs,
                  IntegrationTest,
                  unittest)
from test.utils import (camel_to_snake,
                        camel_to_snake_args,
                        camel_to_upper_camel,
                        CompareType,
                        CMAPListener,
                        OvertCommandListener,
                        parse_spec_options,
                        parse_read_preference,
                        prepare_spec_arguments,
                        rs_client,
                        ServerAndTopologyEventListener,
                        HeartbeatEventListener)


class SpecRunnerThread(threading.Thread):
    """Daemon worker thread that runs scheduled callables in FIFO order.

    Callables are queued with :meth:`schedule` and executed one at a time by
    :meth:`run`. When a callable raises, the exception is stored on ``exc``
    and the thread is marked as stopped; any work that is still queued is
    drained before the thread exits.
    """

    def __init__(self, name):
        """Create the (not yet started) worker thread called ``name``."""
        super(SpecRunnerThread, self).__init__()
        self.name = name
        # Most recent exception raised by a work item, if any.
        self.exc = None
        # Thread.setDaemon() is deprecated; set the attribute instead.
        self.daemon = True
        # Guards ``ops`` and ``stopped`` and wakes up the worker.
        self.cond = threading.Condition()
        self.ops = []
        self.stopped = False

    def schedule(self, work):
        """Queue the zero-argument callable ``work`` for execution."""
        # Append while holding the lock so the worker cannot miss the
        # notification between its "is the queue empty?" check and its wait.
        with self.cond:
            self.ops.append(work)
            self.cond.notify()

    def stop(self):
        """Ask the worker to exit once the queue has been drained."""
        with self.cond:
            self.stopped = True
            self.cond.notify()

    def run(self):
        """Execute queued work until stopped and the queue is empty."""
        while True:
            with self.cond:
                # Re-check the predicate under the lock; the timeout only
                # bounds a single wait, the loop handles wakeups.
                while not self.ops and not self.stopped:
                    self.cond.wait(10)
                if not self.ops:
                    # Stopped and fully drained.
                    return
                work = self.ops.pop(0)
            # Run the work item outside the lock so that it may itself call
            # schedule() or stop().
            try:
                work()
            except Exception as exc:
                self.exc = exc
                self.stop()


class SpecRunner(IntegrationTest):

    @classmethod
    def setUpClass(cls):
        """Run the base class setup and enable short heartbeat knobs.

        ``cls.mongos_clients`` starts out empty and ``cls.knobs`` holds the
        enabled ``client_knobs`` so tearDownClass can disable them.
        """
        super(SpecRunner, cls).setUpClass()
        cls.mongos_clients = []

        # Speed up the tests by decreasing the heartbeat frequency.
        knobs = client_knobs(
            min_heartbeat_interval=0.1, heartbeat_frequency=0.1)
        cls.knobs = knobs
        knobs.enable()

    @classmethod
    def tearDownClass(cls):
        """Disable the knobs enabled by setUpClass, then run base teardown."""
        knobs = cls.knobs
        knobs.disable()
        super(SpecRunner, cls).tearDownClass()

    def setUp(self):
        """Run the base class setUp and reset the per-test attributes."""
        super(SpecRunner, self).setUp()
        self.targets = {}
        # The three listeners are assigned by run_scenario.
        self.listener = self.pool_listener = self.server_listener = None
        # None tells unittest not to truncate assertion diffs.
        self.maxDiff = None

    def _set_fail_point(self, client, command_args):
        """Run a ``configureFailPoint`` admin command on ``client``.

        The command starts as ``configureFailPoint: failCommand`` and is then
        updated with ``command_args``, which may override that value.
        """
        fail_point_cmd = SON([('configureFailPoint', 'failCommand')])
        fail_point_cmd.update(command_args)
        client.admin.command(fail_point_cmd)

    def set_fail_point(self, command_args):
        """Configure a fail point on every mongos client.

        Falls back to ``self.client`` when ``self.mongos_clients`` is empty.
        """
        fail_point_cmd = SON([('configureFailPoint', 'failCommand')])
        fail_point_cmd.update(command_args)
        for target in self.mongos_clients or [self.client]:
            self._set_fail_point(target, fail_point_cmd)

    def targeted_fail_point(self, session, fail_point):
        """Run the targetedFailPoint test operation.

        Enable the fail point on the mongos client whose address matches the
        session's pinned address, and register a cleanup that turns the
        fail point off again.
        """
        by_address = dict(
            (mongos.address, mongos) for mongos in self.mongos_clients)
        pinned_client = by_address[session._pinned_address]
        self._set_fail_point(pinned_client, fail_point)
        self.addCleanup(self.set_fail_point, {'mode': 'off'})

    def assert_session_pinned(self, session):
        """Run the assertSessionPinned test operation.

        Assert that the session's transaction has a pinned address.
        """
        pinned_address = session._transaction.pinned_address
        self.assertIsNotNone(pinned_address)

    def assert_session_unpinned(self, session):
        """Run the assertSessionUnpinned test operation.

        Assert that neither the session nor its transaction has a pinned
        address.
        """
        session_address = session._pinned_address
        self.assertIsNone(session_address)
        transaction_address = session._transaction.pinned_address
        self.assertIsNone(transaction_address)

    def assert_collection_exists(self, database, collection):
        """Run the assertCollectionExists test operation.

        ``database`` and ``collection`` are names; the lookup goes through
        ``self.client``.
        """
        existing = self.client[database].list_collection_names()
        self.assertIn(collection, existing)

    def assert_collection_not_exists(self, database, collection):
        """Run the assertCollectionNotExists test operation.

        ``database`` and ``collection`` are names; the lookup goes through
        ``self.client``.
        """
        existing = self.client[database].list_collection_names()
        self.assertNotIn(collection, existing)

    def assert_index_exists(self, database, collection, index):
        """Run the assertIndexExists test operation.

        ``index`` is compared against the ``name`` field of every document
        returned by ``list_indexes()``.
        """
        target = self.client[database][collection]
        index_names = [spec['name'] for spec in target.list_indexes()]
        self.assertIn(index, index_names)

    def assert_index_not_exists(self, database, collection, index):
        """Run the assertIndexNotExists test operation.

        ``index`` is compared against the ``name`` field of every document
        returned by ``list_indexes()``.
        """
        target = self.client[database][collection]
        index_names = [spec['name'] for spec in target.list_indexes()]
        self.assertNotIn(index, index_names)

    def assertErrorLabelsContain(self, exc, expected_labels):
        """Assert that ``exc`` carries every label in ``expected_labels``."""
        present = [label for label in expected_labels
                   if exc.has_error_label(label)]
        # Any missing label makes ``present`` shorter than the expectation.
        self.assertEqual(present, expected_labels)

    def assertErrorLabelsOmit(self, exc, omit_labels):
        """Assert that ``exc`` carries none of the labels in ``omit_labels``."""
        for label in omit_labels:
            has_label = exc.has_error_label(label)
            self.assertFalse(
                has_label,
                msg='error labels should not contain %s' % (label,))

    def kill_all_sessions(self):
        """Run ``killAllSessions`` on every mongos client.

        Falls back to ``self.client`` when ``self.mongos_clients`` is empty.
        ``OperationFailure`` is ignored.
        """
        targets = self.mongos_clients or [self.client]
        for target in targets:
            try:
                target.admin.command('killAllSessions', [])
            except OperationFailure:
                # "operation was interrupted" by killing the command's
                # own session.
                continue

    def check_command_result(self, expected_result, result):
        """Assert that ``result`` matches ``expected_result`` on its keys.

        Keys of ``result`` that the expectation does not mention are ignored;
        keys that the expectation mentions but ``result`` lacks make the
        comparison fail.
        """
        comparable = {}
        for key in expected_result:
            try:
                value = result[key]
            except KeyError:
                continue
            comparable[key] = value
        self.assertEqual(comparable, expected_result)

    # TODO: factor the following function with test_crud.py.
    def check_result(self, expected_result, result):
        """Assert that ``result`` matches the spec's ``expected_result``.

        A ``_WriteResult`` is compared property by property, translating the
        camelCase spec keys to snake_case attribute names, and True is
        returned. Any other result is compared with ``assertEqual`` and None
        is returned.
        """
        if not isinstance(result, _WriteResult):
            self.assertEqual(result, expected_result)
            return None

        is_bulk = isinstance(result, BulkWriteResult)
        for res in expected_result:
            prop = camel_to_snake(res)
            expected = expected_result[res]
            if prop == "upserted_count" and not is_bulk:
                # SPEC-869: Only BulkWriteResult has upserted_count.
                upserted_count = 0 if result.upserted_id is None else 1
                self.assertEqual(upserted_count, expected, prop)
            elif prop == "inserted_ids":
                if is_bulk:
                    # BulkWriteResult does not have inserted_ids.
                    self.assertEqual(len(expected), result.inserted_count)
                else:
                    # InsertManyResult may be compared to [id1] from the
                    # crud spec or {"0": id1} from the retryable write spec.
                    ids = expected
                    if isinstance(ids, dict):
                        ids = [ids[str(i)] for i in range(len(ids))]
                    self.assertEqual(ids, result.inserted_ids, prop)
            elif prop == "upserted_ids":
                # Convert indexes from strings to integers.
                expected_ids = dict(
                    (int(str_index), expected[str_index])
                    for str_index in expected)
                self.assertEqual(expected_ids, result.upserted_ids, prop)
            else:
                self.assertEqual(getattr(result, prop), expected, prop)

        return True

    def get_object_name(self, op):
        """Return the name of the object that ``op`` should run against.

        Subclasses may override this; the transaction spec says 'object' is
        required, so a missing key is not handled here.
        """
        object_name = op['object']
        return object_name

    @staticmethod
    def parse_options(opts):
        """Return ``opts`` as parsed by ``parse_spec_options``."""
        parsed = parse_spec_options(opts)
        return parsed

    def run_operation(self, sessions, collection, operation):
        """Run a single spec test operation and return its result.

        The camelCase operation name is converted to a method name, the
        target object is chosen via ``get_object_name``, and the operation's
        arguments (merged with its nested "options") are passed to that
        method as keyword arguments. Note that ``operation['arguments']`` is
        modified in place.

        Cursor and CommandCursor results are returned as lists; download
        results are read fully and wrapped in ``Binary``.
        """
        original_collection = collection
        # Spec operation names whose method name differs.
        renames = {
            'run_command': 'command',
            'download_by_name': 'open_download_stream_by_name',
            'download': 'open_download_stream',
        }
        name = camel_to_snake(operation['name'])
        name = renames.get(name, name)

        database = collection.database
        collection = database.get_collection(collection.name)
        if 'collectionOptions' in operation:
            coll_opts = self.parse_options(operation['collectionOptions'])
            collection = collection.with_options(**coll_opts)

        object_name = self.get_object_name(operation)
        if object_name != 'gridfsbucket':
            objects = {
                'client': database.client,
                'database': database,
                'collection': collection,
                'testRunner': self
            }
            objects.update(sessions)
            obj = objects[object_name]
        else:
            # Only create the GridFSBucket when we need it (for the gridfs
            # retryable reads tests).
            obj = GridFSBucket(
                database, bucket_name=collection.name, disable_md5=True)

        # Combine arguments with options and handle special cases.
        arguments = operation.get('arguments', {})
        options = arguments.pop("options", {})
        arguments.update(options)
        self.parse_options(arguments)

        cmd = getattr(obj, name)

        with_txn_callback = functools.partial(
            self.run_operations, sessions, original_collection,
            in_with_transaction=True)
        prepare_spec_arguments(
            operation, arguments, name, sessions, with_txn_callback)

        if name == 'run_on_thread':
            thread_args = {'sessions': sessions, 'collection': collection}
            thread_args.update(arguments)
            arguments = thread_args
        result = cmd(**dict(arguments))

        # Cleanup open change stream cursors.
        if name == "watch":
            self.addCleanup(result.close)

        if name == "aggregate":
            pipeline = arguments["pipeline"]
            if pipeline and "$out" in pipeline[-1]:
                # Read from the primary to ensure causal consistency.
                out = collection.database.get_collection(
                    pipeline[-1]["$out"],
                    read_preference=ReadPreference.PRIMARY)
                return out.find()
        if name == "map_reduce":
            if isinstance(result, dict) and 'results' in result:
                return result['results']
        if 'download' in name:
            result = Binary(result.read())

        if isinstance(result, (Cursor, CommandCursor)):
            return list(result)

        return result

    def allowable_errors(self, op):
        """Return the exception classes an expected error may be raised as.

        The default is ``(PyMongoError,)`` regardless of ``op``; the
        encryption spec overrides this.
        """
        error_classes = (PyMongoError,)
        return error_classes

    def _run_op(self, sessions, collection, op, in_with_transaction):
        """Run ``op`` and verify its expected result or expected error.

        When the operation expects an error, the raised exception is checked
        against the expectations in ``op['result']`` and, if
        ``in_with_transaction`` is true, raised again so the with_transaction
        callback sees it. Otherwise the returned value is checked against
        ``op['result']`` when that key is present.
        """
        expected_result = op.get('result')
        if not expect_error(op):
            result = self.run_operation(sessions, collection, op.copy())
            if 'result' in op:
                if op['name'] == 'runCommand':
                    check = self.check_command_result
                else:
                    check = self.check_result
                check(expected_result, result)
            return

        with self.assertRaises(self.allowable_errors(op),
                               msg=op['name']) as context:
            self.run_operation(sessions, collection, op.copy())
        error = context.exception

        if expect_error_message(expected_result):
            # A BulkWriteError keeps the server messages in its details.
            if isinstance(error, BulkWriteError):
                errmsg = str(error.details).lower()
            else:
                errmsg = str(error).lower()
            self.assertIn(expected_result['errorContains'].lower(), errmsg)
        if expect_error_code(expected_result):
            self.assertEqual(expected_result['errorCodeName'],
                             error.details.get('codeName'))
        if expect_error_labels_contain(expected_result):
            self.assertErrorLabelsContain(
                error, expected_result['errorLabelsContain'])
        if expect_error_labels_omit(expected_result):
            self.assertErrorLabelsOmit(
                error, expected_result['errorLabelsOmit'])

        # Reraise the exception if we're in the with_transaction
        # callback.
        if in_with_transaction:
            raise error

    def run_operations(self, sessions, collection, ops,
                       in_with_transaction=False):
        """Run every operation in ``ops`` in order via ``_run_op``."""
        for operation in ops:
            self._run_op(
                sessions, collection, operation, in_with_transaction)

    # TODO: factor with test_command_monitoring.py
    def check_events(self, test, listener, session_ids):
        """Compare the listener's started events to ``test['expectations']``.

        Does nothing when the test has no expectations. Otherwise the number
        of started events must equal the number of expectations, and each
        expectation is compared, by position, with the started event at the
        same index.

        Both the recorded events and the expectations are modified in place
        to normalize values that cannot be known ahead of time (cursor ids,
        afterClusterTime, recoveryToken, lsid).

        ``session_ids`` maps a session name such as "session0" to the lsid
        that session used.
        """
        res = listener.results
        if not len(test['expectations']):
            return

        # Give a nicer message when there are missing or extra events
        cmds = decode_raw([event.command for event in res['started']])
        self.assertEqual(
            len(res['started']), len(test['expectations']), cmds)
        for i, expectation in enumerate(test['expectations']):
            # The first key of the expectation names the event type.
            event_type = next(iter(expectation))
            event = res['started'][i]

            # The tests substitute 42 for any number other than 0.
            if (event.command_name == 'getMore'
                    and event.command['getMore']):
                event.command['getMore'] = Int64(42)
            elif event.command_name == 'killCursors':
                event.command['cursors'] = [Int64(42)]
            elif event.command_name == 'update':
                # TODO: remove this once PYTHON-1744 is done.
                # Add upsert and multi fields back into expectations.
                updates = expectation[event_type]['command']['updates']
                for update in updates:
                    update.setdefault('upsert', False)
                    update.setdefault('multi', False)

            # Replace afterClusterTime: 42 with actual afterClusterTime.
            expected_cmd = expectation[event_type]['command']
            expected_read_concern = expected_cmd.get('readConcern')
            if expected_read_concern is not None:
                time = expected_read_concern.get('afterClusterTime')
                if time == 42:
                    actual_time = event.command.get(
                        'readConcern', {}).get('afterClusterTime')
                    if actual_time is not None:
                        expected_read_concern['afterClusterTime'] = actual_time

            # A recoveryToken of 42 means "any dict".
            recovery_token = expected_cmd.get('recoveryToken')
            if recovery_token == 42:
                expected_cmd['recoveryToken'] = CompareType(dict)

            # Replace lsid with a name like "session0" to match test.
            if 'lsid' in event.command:
                for name, lsid in session_ids.items():
                    if event.command['lsid'] == lsid:
                        event.command['lsid'] = name
                        break

            for attr, expected in expectation[event_type].items():
                actual = getattr(event, attr)
                expected = wrap_types(expected)
                if isinstance(expected, dict):
                    for key, val in expected.items():
                        # An expected value of None means the key must be
                        # absent from the actual event attribute.
                        if val is None:
                            if key in actual:
                                self.fail("Unexpected key [%s] in %r" % (
                                    key, actual))
                        elif key not in actual:
                            self.fail("Expected key [%s] in %r" % (
                                key, actual))
                        else:
                            self.assertEqual(val, decode_raw(actual[key]),
                                             "Key [%s] in %s" % (key, actual))
                else:
                    self.assertEqual(actual, expected)

    def maybe_skip_scenario(self, test):
        """Skip the current test when ``test`` has a truthy 'skipReason'."""
        reason = test.get('skipReason')
        if reason:
            self.skipTest(reason)

    def get_scenario_db_name(self, scenario_def):
        """Return the scenario's 'database_name'; subclasses may override."""
        db_name = scenario_def['database_name']
        return db_name

    def get_scenario_coll_name(self, scenario_def):
        """Return the scenario's 'collection_name'; subclasses may override."""
        coll_name = scenario_def['collection_name']
        return coll_name

    def get_outcome_coll_name(self, outcome, collection):
        """Return the name of the collection holding the test outcome.

        The default ignores ``outcome`` and uses ``collection.name``;
        subclasses may override.
        """
        outcome_name = collection.name
        return outcome_name

    def run_test_ops(self, sessions, collection, test):
        """Run the operations listed under ``test['operations']``.

        Exists so that the retryable writes spec can override how a test's
        operation is run.
        """
        operations = test['operations']
        self.run_operations(sessions, collection, operations)

    def parse_client_options(self, opts):
        """Return the client options as a plain dict.

        Exists so that the encryption spec can override clientOptions
        parsing.
        """
        # Convert test['clientOptions'] to dict to avoid a Jython bug using
        # "**" with ScenarioDict.
        plain_opts = dict(opts)
        return plain_opts

    def setup_scenario(self, scenario_def):
        """Recreate the scenario's collection and load its data.

        The collection is dropped, created again and, when the scenario has
        data, populated with it, all with a majority write concern. Specs
        may override this.
        """
        db_name = self.get_scenario_db_name(scenario_def)
        coll_name = self.get_scenario_coll_name(scenario_def)
        majority = WriteConcern(w='majority')
        db = client_context.client.get_database(
            db_name, write_concern=majority)
        coll = db[coll_name]
        coll.drop()
        db.create_collection(coll_name)
        data = scenario_def['data']
        if data:
            # Load data.
            coll.insert_many(data)

    def run_scenario(self, scenario_def, test):
        """Run one spec test from a scenario, end to end.

        In order: skip the test if it asks for that, kill all sessions, set
        up the scenario's collection, configure the test's fail point (if
        any), create a fresh client with command, pool and server listeners,
        start "session0" and "session1" (when sessions are enabled), run the
        test operations, end the sessions, check the recorded command events
        and, when the outcome lists collection data, compare it with the
        actual collection contents.

        The client and the three listeners are stored on ``self`` as
        ``scenario_client``, ``listener``, ``pool_listener`` and
        ``server_listener``.
        """
        self.maybe_skip_scenario(test)

        # Kill all sessions before and after each test to prevent an open
        # transaction (from a test failure) from blocking collection/database
        # operations during test set up and tear down.
        self.kill_all_sessions()
        self.addCleanup(self.kill_all_sessions)
        self.setup_scenario(scenario_def)
        database_name = self.get_scenario_db_name(scenario_def)
        collection_name = self.get_scenario_coll_name(scenario_def)
        # SPEC-1245 workaround StaleDbVersion on distinct
        for c in self.mongos_clients:
            c[database_name][collection_name].distinct("x")

        # Configure the fail point before creating the client.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point(fp)
            # Make sure the fail point is turned off even if the test fails.
            self.addCleanup(self.set_fail_point, {
                'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})

        listener = OvertCommandListener()
        pool_listener = CMAPListener()
        server_listener = ServerAndTopologyEventListener()
        # Create a new client, to avoid interference from pooled sessions.
        client_options = self.parse_client_options(test['clientOptions'])
        # MMAPv1 does not support retryable writes.
        if (client_options.get('retryWrites') is True and
                client_context.storage_engine == 'mmapv1'):
            self.skipTest("MMAPv1 does not support retryWrites=True")
        use_multi_mongos = test['useMultipleMongoses']
        # host stays None unless the test asks for multiple mongoses and
        # one of the branches below supplies an address.
        host = None
        if use_multi_mongos:
            if client_context.load_balancer or client_context.serverless:
                host = client_context.MULTI_MONGOS_LB_URI
            elif client_context.is_mongos:
                host = client_context.mongos_seeds()
        client = rs_client(
            h=host,
            event_listeners=[listener, pool_listener, server_listener],
            **client_options)
        self.scenario_client = client
        self.listener = listener
        self.pool_listener = pool_listener
        self.server_listener = server_listener
        # Close the client explicitly to avoid having too many threads open.
        self.addCleanup(client.close)

        # Create session0 and session1.
        sessions = {}
        session_ids = {}
        for i in range(2):
            # Don't attempt to create sessions if they are not supported by
            # the running server version.
            if not client_context.sessions_enabled:
                break
            session_name = 'session%d' % i
            opts = camel_to_snake_args(test['sessionOptions'][session_name])
            if 'default_transaction_options' in opts:
                txn_opts = self.parse_options(
                    opts['default_transaction_options'])
                txn_opts = client_session.TransactionOptions(**txn_opts)
                opts['default_transaction_options'] = txn_opts

            s = client.start_session(**dict(opts))

            sessions[session_name] = s
            # Store lsid so we can access it after end_session, in check_events.
            session_ids[session_name] = s.session_id

        # Also end the sessions during cleanup, in case the operations below
        # raise before the explicit end_sessions call is reached.
        self.addCleanup(end_sessions, sessions)

        collection = client[database_name][collection_name]
        self.run_test_ops(sessions, collection, test)

        end_sessions(sessions)

        self.check_events(test, listener, session_ids)

        # Disable fail points.
        if 'failPoint' in test:
            fp = test['failPoint']
            self.set_fail_point({
                'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'})

        # Assert final state is expected.
        outcome = test['outcome']
        expected_c = outcome.get('collection')
        if expected_c is not None:
            outcome_coll_name = self.get_outcome_coll_name(
                outcome, collection)

            # Read from the primary with local read concern to ensure causal
            # consistency.
            outcome_coll = client_context.client[
                collection.database.name].get_collection(
                outcome_coll_name,
                read_preference=ReadPreference.PRIMARY,
                read_concern=ReadConcern('local'))
            actual_data = list(outcome_coll.find(sort=[('_id', 1)]))

            # The expected data needs to be the left hand side here otherwise
            # CompareType(Binary) doesn't work.
            self.assertEqual(wrap_types(expected_c['data']), actual_data)


def expect_any_error(op):
    """Return the operation's 'error' value, or False for a non-dict.

    A dict without an 'error' key yields None.
    """
    if not isinstance(op, dict):
        return False

    return op.get('error')


def expect_error_message(expected_result):
    """Return True if the expected result specifies an error message.

    That is the case when ``expected_result`` is a dict whose
    'errorContains' value is a text string. A dict without that key, or
    anything that is not a dict, yields False.
    """
    if isinstance(expected_result, dict):
        # Use get() so a plain dict without the key does not raise KeyError.
        return isinstance(expected_result.get('errorContains'), text_type)

    return False


def expect_error_code(expected_result):
    """Return the expected 'errorCodeName', or a falsy value if none.

    A dict without the key yields None; anything that is not a dict yields
    False.
    """
    if isinstance(expected_result, dict):
        # Use get() so a plain dict without the key does not raise KeyError.
        return expected_result.get('errorCodeName')

    return False


def expect_error_labels_contain(expected_result):
    """Return the expected 'errorLabelsContain', or a falsy value if none.

    A dict without the key yields None; anything that is not a dict yields
    False.
    """
    if isinstance(expected_result, dict):
        # Use get() so a plain dict without the key does not raise KeyError.
        return expected_result.get('errorLabelsContain')

    return False


def expect_error_labels_omit(expected_result):
    """Return the expected 'errorLabelsOmit', or a falsy value if none.

    A dict without the key yields None; anything that is not a dict yields
    False.
    """
    if isinstance(expected_result, dict):
        # Use get() so a plain dict without the key does not raise KeyError.
        return expected_result.get('errorLabelsOmit')

    return False


def expect_error(op):
    """Return a truthy value if ``op`` is expected to raise an error.

    The operation's 'error' flag is consulted first, then each of the error
    expectations in its 'result'. The first truthy answer is returned; when
    there is none, the (falsy) answer of the last check is returned.
    """
    expected_result = op.get('result')
    outcome = expect_any_error(op)
    if outcome:
        return outcome
    for probe in (expect_error_message,
                  expect_error_code,
                  expect_error_labels_contain,
                  expect_error_labels_omit):
        outcome = probe(expected_result)
        if outcome:
            return outcome
    return outcome


def end_sessions(sessions):
    """Call ``end_session()`` on every value of the ``sessions`` dict."""
    for session in sessions.values():
        # Aborts the transaction if it's open.
        session.end_session()


# Codec options used by decode_raw: documents are decoded as plain dicts
# and the STANDARD UUID representation is used.
OPTS = CodecOptions(document_class=dict, uuid_representation=STANDARD)


def decode_raw(val):
    """Decode RawBSONDocuments in the given container.

    Lists and mappings are round-tripped through BSON (wrapped in a
    one-key document) using ``OPTS``; any other value is returned
    unchanged.
    """
    if not isinstance(val, (list, abc.Mapping)):
        return val
    wrapped = encode({'v': val}, codec_options=OPTS)
    return decode(wrapped, OPTS)['v']


# Maps the type names accepted by the "$$type" assertion (see wrap_types)
# to the Python class an actual value is compared against.
TYPES = {
    'binData': Binary,
    'long': Int64,
}


def wrap_types(val):
    """Support $$type assertion in command results.

    Lists and mappings are rebuilt recursively (mappings become plain
    dicts). A mapping with a truthy '$$type' value is replaced by a
    ``CompareType`` for the matching class in ``TYPES``. Any other value is
    returned unchanged.
    """
    if isinstance(val, list):
        return [wrap_types(item) for item in val]
    if not isinstance(val, abc.Mapping):
        return val

    type_name = val.get('$$type')
    if type_name:
        return CompareType(TYPES[type_name])
    return {key: wrap_types(val[key]) for key in val}