# Repository URL to install this package:
# Version: 3.12.2
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified test format runner.
https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst
"""
import copy
import datetime
import functools
import os
import re
import sys
import types
from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey
from bson.binary import Binary
from bson.objectid import ObjectId
from bson.py3compat import abc, integer_types, iteritems, text_type, PY3
from bson.regex import Regex, RE_TYPE
from gridfs import GridFSBucket
from pymongo import ASCENDING, MongoClient
from pymongo.client_session import ClientSession, TransactionOptions, _TxnState
from pymongo.change_stream import ChangeStream
from pymongo.collection import Collection
from pymongo.database import Database
from pymongo.errors import (
BulkWriteError, ConnectionFailure, ConfigurationError, InvalidOperation,
NotPrimaryError, PyMongoError)
from pymongo.monitoring import (
CommandFailedEvent, CommandListener, CommandStartedEvent,
CommandSucceededEvent, _SENSITIVE_COMMANDS, PoolCreatedEvent,
PoolClearedEvent, PoolClosedEvent, ConnectionCreatedEvent,
ConnectionReadyEvent, ConnectionClosedEvent,
ConnectionCheckOutStartedEvent, ConnectionCheckOutFailedEvent,
ConnectionCheckedOutEvent, ConnectionCheckedInEvent)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult
from pymongo.server_api import ServerApi
from pymongo.write_concern import WriteConcern
from test import client_context, unittest, IntegrationTest
from test.utils import (
camel_to_snake, get_pool, rs_or_single_client, single_client,
snake_to_camel, CMAPListener)
from test.version import Version
from test.utils import (
camel_to_snake_args, parse_collection_options, parse_spec_options,
prepare_spec_arguments)
# Extended JSON decoding options with tz_aware disabled, so decoded
# datetimes are naive.
JSON_OPTS = json_util.JSONOptions(
    tz_aware=False,
)
def with_metaclass(meta, *bases):
    """Create a base class with a metaclass.

    Vendored from six: https://github.com/benjaminp/six/blob/master/six.py

    Returns a throwaway ``temporary_class``. When a class statement
    subclasses it, the dummy metaclass below builds the real class by
    calling ``meta(name, bases, namespace)`` instead, so ``temporary_class``
    does not end up in the final class's bases.

    :Parameters:
      - `meta`: the metaclass the final class is created with.
      - `*bases`: the base classes of the final class.
    """
    # This requires a bit of explanation: the basic idea is to make a dummy
    # metaclass for one level of class instantiation that replaces itself with
    # the actual metaclass.
    class metaclass(type):

        def __new__(cls, name, this_bases, d):
            # `this_bases` (i.e. the temporary class) is intentionally
            # ignored; the real bases come from the enclosing closure.
            if sys.version_info[:2] >= (3, 7):
                # This version introduced PEP 560 that requires a bit
                # of extra care (we mimic what is done by __build_class__).
                resolved_bases = types.resolve_bases(bases)
                if resolved_bases is not bases:
                    d['__orig_bases__'] = bases
            else:
                resolved_bases = bases
            return meta(name, resolved_bases, d)

        @classmethod
        def __prepare__(cls, name, this_bases):
            # Delegate namespace preparation to the real metaclass, again
            # using the real bases rather than the temporary class.
            return meta.__prepare__(name, bases)
    return type.__new__(metaclass, 'temporary_class', (), {})
def is_run_on_requirement_satisfied(requirement):
    """Return whether the deployment meets one runOnRequirement document.

    Each optional key (topologies, min/maxServerVersion, serverless,
    serverParameters, auth) that is present must hold. Absent keys impose
    no constraint. All checks are evaluated before the results are combined.
    """
    req_topologies = requirement.get('topologies')
    topology_ok = (client_context.is_topology_type(req_topologies)
                   if req_topologies else True)

    server_version = Version(*client_context.version[:3])

    min_version = requirement.get('minServerVersion')
    min_ok = (Version.from_string(min_version) <= server_version
              if min_version else True)

    max_version = requirement.get('maxServerVersion')
    max_ok = (Version.from_string(max_version) >= server_version
              if max_version else True)

    serverless_mode = requirement.get('serverless')
    if serverless_mode == "require":
        serverless_ok = client_context.serverless
    elif serverless_mode == "forbid":
        serverless_ok = not client_context.serverless
    else:  # unset or "allow"
        serverless_ok = True

    # Every listed server parameter must exist with exactly the given value.
    params_ok = True
    for param_name, expected in (requirement.get('serverParameters')
                                 or {}).items():
        if (param_name not in client_context.server_parameters or
                client_context.server_parameters[param_name] != expected):
            params_ok = False

    req_auth = requirement.get('auth')
    if req_auth is None:
        auth_ok = True
    elif req_auth:
        auth_ok = client_context.auth_enabled
    else:
        auth_ok = not client_context.auth_enabled

    return (topology_ok and min_ok and max_ok and serverless_ok and
            params_ok and auth_ok)
def parse_collection_or_database_options(options):
    """Parse a collectionOptions or databaseOptions spec document.

    Both kinds of options are handled by ``parse_collection_options``.
    """
    parsed_options = parse_collection_options(options)
    return parsed_options
def parse_bulk_write_result(result):
    """Convert a bulk write result into the spec's result document.

    The keys of ``upsertedIds`` are the operation indexes as strings.
    """
    upserted = result.upserted_ids
    return {
        'deletedCount': result.deleted_count,
        'insertedCount': result.inserted_count,
        'matchedCount': result.matched_count,
        'modifiedCount': result.modified_count,
        'upsertedCount': result.upserted_count,
        'upsertedIds': dict(
            (str(op_index), upserted[op_index]) for op_index in upserted),
    }
def parse_bulk_write_error_result(error):
    """Build the spec's bulkWrite result document from a BulkWriteError.

    The partial result stored in ``error.details`` is wrapped as an
    acknowledged BulkWriteResult before being converted.
    """
    return parse_bulk_write_result(BulkWriteResult(error.details, True))
class NonLazyCursor(object):
    """A find cursor proxy that creates the remote cursor when initialized.

    The first document is pulled eagerly in ``__init__`` and handed back on
    the first call to ``next()``; later calls delegate to the wrapped cursor.
    """

    def __init__(self, find_cursor):
        self.find_cursor = find_cursor
        # Pulling one document forces the server side cursor to be created.
        self.first_result = next(find_cursor, None)

    def __iter__(self):
        return self

    def __next__(self):
        # Hand out the buffered first document exactly once.
        buffered, self.first_result = self.first_result, None
        if buffered is None:
            return next(self.find_cursor)
        return buffered

    next = __next__

    def close(self):
        self.find_cursor.close()
class EventListenerUtil(CMAPListener, CommandListener):
    """Event listener for a unified-format client entity.

    Only events whose lowercased class name is in ``observe_events`` are
    stored. Command events for commands in the ignore set are dropped. The
    ignore set always contains ``configurefailpoint``, and it also contains
    ``_SENSITIVE_COMMANDS`` unless ``observe_sensitive_commands`` is true.
    """

    def __init__(self, observe_events, ignore_commands,
                 observe_sensitive_commands):
        self._event_types = {name.lower() for name in observe_events}
        self._observe_sensitive_commands = bool(observe_sensitive_commands)
        ignored = set(ignore_commands)
        if not observe_sensitive_commands:
            ignored |= _SENSITIVE_COMMANDS
        ignored.add('configurefailpoint')
        self._ignore_commands = ignored
        super(EventListenerUtil, self).__init__()

    def get_events(self, event_type):
        # 'command' selects command events; anything else selects CMAP events.
        want_command = event_type == 'command'
        return [e for e in self.events
                if ('Command' in type(e).__name__) == want_command]

    def add_event(self, event):
        if type(event).__name__.lower() not in self._event_types:
            return
        super(EventListenerUtil, self).add_event(event)

    def _command_event(self, event):
        if event.command_name.lower() in self._ignore_commands:
            return
        self.add_event(event)

    def started(self, event):
        # An empty command means it was redacted; record those only when
        # sensitive commands are being observed.
        if event.command != {} or self._observe_sensitive_commands:
            self._command_event(event)

    def succeeded(self, event):
        # An empty reply means it was redacted; same rule as started().
        if event.reply != {} or self._observe_sensitive_commands:
            self._command_event(event)

    def failed(self, event):
        self._command_event(event)
class EntityMapUtil(object):
    """Utility class that implements an entity map as per the unified
    test format specification.

    Entities are keyed by their spec ``id``. Event listeners are tracked per
    client id. Each session's lsid is copied when the session is created, so
    it can still be returned after the session has been ended.
    """

    def __init__(self, test_class):
        self._entities = {}
        # client entity id -> EventListenerUtil
        self._listeners = {}
        # session entity id -> deep copy of its lsid taken at creation
        self._session_lsids = {}
        self.test = test_class

    def __getitem__(self, item):
        try:
            return self._entities[item]
        except KeyError:
            self.test.fail('Could not find entity named %s in map' % (
                item,))

    def __setitem__(self, key, value):
        if not isinstance(key, text_type):
            self.test.fail(
                'Expected entity name of type str, got %s' % (type(key)))
        if key in self._entities:
            self.test.fail('Entity named %s already in map' % (key,))
        self._entities[key] = value

    def _create_entity(self, entity_spec):
        """Create and register the single entity described by entity_spec."""
        if len(entity_spec) != 1:
            self.test.fail(
                "Entity spec %s did not contain exactly one top-level key" % (
                    entity_spec,))

        entity_type, spec = next(iteritems(entity_spec))
        if entity_type == 'client':
            kwargs = {}
            observe_events = spec.get('observeEvents', [])
            ignore_commands = spec.get('ignoreCommandMonitoringEvents', [])
            observe_sensitive_commands = spec.get(
                'observeSensitiveCommands', False)
            # TODO: PYTHON-2511 support storeEventsAsEntities
            if observe_events or ignore_commands:
                ignore_commands = [cmd.lower() for cmd in ignore_commands]
                listener = EventListenerUtil(
                    observe_events, ignore_commands,
                    observe_sensitive_commands)
                self._listeners[spec['id']] = listener
                kwargs['event_listeners'] = [listener]
            if spec.get('useMultipleMongoses'):
                if client_context.load_balancer or client_context.serverless:
                    kwargs['h'] = client_context.MULTI_MONGOS_LB_URI
                elif client_context.is_mongos:
                    kwargs['h'] = client_context.mongos_seeds()
            kwargs.update(spec.get('uriOptions', {}))
            server_api = spec.get('serverApi')
            if server_api:
                kwargs['server_api'] = ServerApi(
                    server_api['version'], strict=server_api.get('strict'),
                    deprecation_errors=server_api.get('deprecationErrors'))
            client = rs_or_single_client(**kwargs)
            self[spec['id']] = client
            self.test.addCleanup(client.close)
            return
        elif entity_type == 'database':
            client = self[spec['client']]
            if not isinstance(client, MongoClient):
                self.test.fail(
                    'Expected entity %s to be of type MongoClient, got %s' % (
                        spec['client'], type(client)))
            options = parse_collection_or_database_options(
                spec.get('databaseOptions', {}))
            self[spec['id']] = client.get_database(
                spec['databaseName'], **options)
            return
        elif entity_type == 'collection':
            database = self[spec['database']]
            if not isinstance(database, Database):
                self.test.fail(
                    'Expected entity %s to be of type Database, got %s' % (
                        spec['database'], type(database)))
            options = parse_collection_or_database_options(
                spec.get('collectionOptions', {}))
            self[spec['id']] = database.get_collection(
                spec['collectionName'], **options)
            return
        elif entity_type == 'session':
            client = self[spec['client']]
            if not isinstance(client, MongoClient):
                self.test.fail(
                    'Expected entity %s to be of type MongoClient, got %s' % (
                        spec['client'], type(client)))
            opts = camel_to_snake_args(spec.get('sessionOptions', {}))
            if 'default_transaction_options' in opts:
                txn_opts = parse_spec_options(
                    opts['default_transaction_options'])
                txn_opts = TransactionOptions(**txn_opts)
                opts = copy.deepcopy(opts)
                opts['default_transaction_options'] = txn_opts
            session = client.start_session(**dict(opts))
            self[spec['id']] = session
            # Snapshot the lsid: session_id raises once the session ends.
            self._session_lsids[spec['id']] = copy.deepcopy(session.session_id)
            self.test.addCleanup(session.end_session)
            return
        elif entity_type == 'bucket':
            # TODO: implement the 'bucket' entity type
            self.test.skipTest(
                'GridFS is not currently supported (PYTHON-2459)')
        self.test.fail(
            'Unable to create entity of unknown type %s' % (entity_type,))

    def create_entities_from_spec(self, entity_spec):
        for spec in entity_spec:
            self._create_entity(spec)

    def get_listener_for_client(self, client_name):
        client = self[client_name]
        if not isinstance(client, MongoClient):
            self.test.fail(
                'Expected entity %s to be of type MongoClient, got %s' % (
                    client_name, type(client)))

        listener = self._listeners.get(client_name)
        # Identity check: a listener object must never be judged by truthiness.
        if listener is None:
            self.test.fail(
                'No listeners configured for client %s' % (client_name,))

        return listener

    def get_lsid_for_session(self, session_name):
        session = self[session_name]
        if not isinstance(session, ClientSession):
            self.test.fail(
                'Expected entity %s to be of type ClientSession, got %s' % (
                    session_name, type(session)))

        try:
            return session.session_id
        except InvalidOperation:
            # session has been closed.
            return self._session_lsids[session_name]
# Python 2/3 compatible type tuples referenced by BSON_TYPE_ALIAS_MAP.
if PY3:
    binary_types = (Binary, bytes)
    long_types = (Int64,)
    unicode_type = str
else:
    binary_types = (Binary,)
    long_types = (Int64, long)
    unicode_type = unicode

# Maps each BSON type alias accepted by the $$type operator to the Python
# type(s) a value of that BSON type may decode to.
# https://docs.mongodb.com/manual/reference/operator/query/type/
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
BSON_TYPE_ALIAS_MAP = dict(
    double=(float,),
    string=(text_type,),
    object=(abc.Mapping,),
    array=(abc.MutableSequence,),
    binData=binary_types,
    undefined=(type(None),),
    objectId=(ObjectId,),
    bool=(bool,),
    date=(datetime.datetime,),
    null=(type(None),),
    regex=(Regex, RE_TYPE),
    dbPointer=(DBRef,),
    javascript=(unicode_type, Code),
    symbol=(unicode_type,),
    javascriptWithScope=(unicode_type, Code),
    int=(int,),
    long=long_types,
    decimal=(Decimal128,),
    maxKey=(MaxKey,),
    minKey=(MinKey,),
)
class MatchEvaluatorUtil(object):
    """Utility class that implements methods for evaluating matches as per
    the unified test format specification."""

    def __init__(self, test_class):
        self.test = test_class

    def _operation_exists(self, spec, actual, key_to_compare):
        # $$exists: true asserts the key is present, false that it is absent.
        if spec is True:
            self.test.assertIn(key_to_compare, actual)
        elif spec is False:
            self.test.assertNotIn(key_to_compare, actual)
        else:
            self.test.fail(
                'Expected boolean value for $$exists operator, got %s' % (
                    spec,))

    def __type_alias_to_type(self, alias):
        if alias not in BSON_TYPE_ALIAS_MAP:
            self.test.fail('Unrecognized BSON type alias %s' % (alias,))
        return BSON_TYPE_ALIAS_MAP[alias]

    def _operation_type(self, spec, actual, key_to_compare):
        # $$type accepts either one alias or a list of aliases.
        if isinstance(spec, abc.MutableSequence):
            permissible_types = tuple([
                t for alias in spec for t in self.__type_alias_to_type(alias)])
        else:
            permissible_types = self.__type_alias_to_type(spec)
        value = actual[key_to_compare] if key_to_compare else actual
        self.test.assertIsInstance(value, permissible_types)

    def _operation_matchesEntity(self, spec, actual, key_to_compare):
        expected_entity = self.test.entity_map[spec]
        self.test.assertIsInstance(expected_entity, abc.Mapping)
        self.test.assertEqual(expected_entity, actual[key_to_compare])

    def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
        raise NotImplementedError

    def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
        if key_to_compare is None:
            if not actual:
                # top-level document can be None when unset
                return
            # A top-level value is present, so it must actually match spec
            # (previously this path wrote a None key into actual and
            # silently passed).
            return self.match_result(spec, actual)

        if key_to_compare not in actual:
            # we add a dummy value for the compared key to pass map size check
            actual[key_to_compare] = 'dummyValue'
            return
        self.match_result(spec, actual[key_to_compare], in_recursive_call=True)

    def _operation_sessionLsid(self, spec, actual, key_to_compare):
        expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
        self.test.assertEqual(expected_lsid, actual[key_to_compare])

    def _evaluate_special_operation(self, opname, spec, actual,
                                    key_to_compare):
        # '$$foo' dispatches to self._operation_foo.
        method_name = '_operation_%s' % (opname.strip('$'),)
        try:
            method = getattr(self, method_name)
        except AttributeError:
            self.test.fail(
                'Unsupported special matching operator %s' % (opname,))
        else:
            method(spec, actual, key_to_compare)

    def _evaluate_if_special_operation(self, expectation, actual,
                                       key_to_compare=None):
        """Returns True if a special operation is evaluated, False
        otherwise. If the ``expectation`` map contains a single key,
        value pair we check it for a special operation.
        If given, ``key_to_compare`` is assumed to be the key in
        ``expectation`` whose corresponding value needs to be
        evaluated for a possible special operation. ``key_to_compare``
        is ignored when ``expectation`` has only one key."""
        if not isinstance(expectation, abc.Mapping):
            return False

        is_special_op, opname, spec = False, False, False

        if key_to_compare is not None:
            if key_to_compare.startswith('$$'):
                is_special_op = True
                opname = key_to_compare
                spec = expectation[key_to_compare]
                key_to_compare = None
            else:
                nested = expectation[key_to_compare]
                if isinstance(nested, abc.Mapping) and len(nested) == 1:
                    opname, spec = next(iteritems(nested))
                    if opname.startswith('$$'):
                        is_special_op = True
        elif len(expectation) == 1:
            opname, spec = next(iteritems(expectation))
            if opname.startswith('$$'):
                is_special_op = True
                key_to_compare = None

        if is_special_op:
            self._evaluate_special_operation(
                opname=opname,
                spec=spec,
                actual=actual,
                key_to_compare=key_to_compare)
            return True

        return False

    def _match_document(self, expectation, actual, is_root):
        if self._evaluate_if_special_operation(expectation, actual):
            return

        self.test.assertIsInstance(actual, abc.Mapping)
        for key, value in iteritems(expectation):
            if self._evaluate_if_special_operation(expectation, actual, key):
                continue

            self.test.assertIn(key, actual)
            self.match_result(value, actual[key], in_recursive_call=True)

        # Non-root documents must not contain extra keys.
        if not is_root:
            expected_keys = set(expectation.keys())
            for key, value in expectation.items():
                if value == {'$$exists': False}:
                    expected_keys.remove(key)
            self.test.assertEqual(expected_keys, set(actual.keys()))

    def match_result(self, expectation, actual,
                     in_recursive_call=False):
        if isinstance(expectation, abc.Mapping):
            return self._match_document(
                expectation, actual, is_root=not in_recursive_call)

        if isinstance(expectation, abc.MutableSequence):
            self.test.assertIsInstance(actual, abc.MutableSequence)
            for e, a in zip(expectation, actual):
                if isinstance(e, abc.Mapping):
                    self._match_document(
                        e, a, is_root=not in_recursive_call)
                else:
                    self.match_result(e, a, in_recursive_call=True)
            return

        # account for flexible numerics in element-wise comparison and unicode
        # vs str on Python 2.
        if not (isinstance(expectation, integer_types) or
                isinstance(expectation, float) or
                (isinstance(expectation, unicode_type) and not PY3)):
            self.test.assertIsInstance(actual, type(expectation))
        self.test.assertEqual(expectation, actual)

    def assertHasServiceId(self, spec, actual):
        if 'hasServiceId' in spec:
            if spec.get('hasServiceId'):
                self.test.assertIsNotNone(actual.service_id)
                self.test.assertIsInstance(actual.service_id, ObjectId)
            else:
                self.test.assertIsNone(actual.service_id)

    def match_event(self, event_type, expectation, actual):
        name, spec = next(iteritems(expectation))

        # every command event has the commandName field
        if event_type == 'command':
            command_name = spec.get('commandName')
            if command_name:
                self.test.assertEqual(command_name, actual.command_name)

        if name == 'commandStartedEvent':
            self.test.assertIsInstance(actual, CommandStartedEvent)
            command = spec.get('command')
            database_name = spec.get('databaseName')
            if command:
                if actual.command_name == 'update':
                    # TODO: remove this once PYTHON-1744 is done.
                    # Add upsert and multi fields back into expectations.
                    for update in command.get('updates', []):
                        update.setdefault('upsert', False)
                        update.setdefault('multi', False)
                self.match_result(command, actual.command)
            if database_name:
                self.test.assertEqual(
                    database_name, actual.database_name)
            self.assertHasServiceId(spec, actual)
        elif name == 'commandSucceededEvent':
            self.test.assertIsInstance(actual, CommandSucceededEvent)
            reply = spec.get('reply')
            if reply:
                self.match_result(reply, actual.reply)
            self.assertHasServiceId(spec, actual)
        elif name == 'commandFailedEvent':
            self.test.assertIsInstance(actual, CommandFailedEvent)
            self.assertHasServiceId(spec, actual)
        elif name == 'poolCreatedEvent':
            self.test.assertIsInstance(actual, PoolCreatedEvent)
        elif name == 'poolReadyEvent':
            # PyMongo 3.X does not support PoolReadyEvent. Use fail() rather
            # than a bare assert, which is stripped under ``python -O``.
            self.test.fail('PyMongo 3.X does not support PoolReadyEvent')
        elif name == 'poolClearedEvent':
            self.test.assertIsInstance(actual, PoolClearedEvent)
            self.assertHasServiceId(spec, actual)
        elif name == 'poolClosedEvent':
            self.test.assertIsInstance(actual, PoolClosedEvent)
        elif name == 'connectionCreatedEvent':
            self.test.assertIsInstance(actual, ConnectionCreatedEvent)
        elif name == 'connectionReadyEvent':
            self.test.assertIsInstance(actual, ConnectionReadyEvent)
        elif name == 'connectionClosedEvent':
            self.test.assertIsInstance(actual, ConnectionClosedEvent)
            if 'reason' in spec:
                self.test.assertEqual(actual.reason, spec['reason'])
        elif name == 'connectionCheckOutStartedEvent':
            self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
        elif name == 'connectionCheckOutFailedEvent':
            self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
            if 'reason' in spec:
                self.test.assertEqual(actual.reason, spec['reason'])
        elif name == 'connectionCheckedOutEvent':
            self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
        elif name == 'connectionCheckedInEvent':
            self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
        else:
            self.test.fail(
                'Unsupported event type %s' % (name,))
def coerce_result(opname, result):
    """Convert a pymongo result into the spec's result format.

    Unacknowledged results become ``{'acknowledged': False}``. Results of
    operations not handled below are returned unchanged.
    """
    if hasattr(result, 'acknowledged') and not result.acknowledged:
        return {'acknowledged': False}
    if opname == 'bulkWrite':
        return parse_bulk_write_result(result)
    if opname == 'insertOne':
        return {'insertedId': result.inserted_id}
    if opname == 'insertMany':
        return {idx: _id for idx, _id in enumerate(result.inserted_ids)}
    if opname in ('deleteOne', 'deleteMany'):
        return {'deletedCount': result.deleted_count}
    if opname in ('updateOne', 'updateMany', 'replaceOne'):
        upserted_id = result.upserted_id
        coerced = {
            'matchedCount': result.matched_count,
            'modifiedCount': result.modified_count,
            'upsertedCount': 0 if upserted_id is None else 1,
        }
        # Expose the upserted _id so expectResult can assert on upsertedId.
        if upserted_id is not None:
            coerced['upsertedId'] = upserted_id
        return coerced
    return result
class UnifiedSpecTestMixinV1(IntegrationTest):
"""Mixin class to run test cases from test specification files.
Assumes that tests conform to the `unified test format
<https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst>`_.
Specification of the test suite being currently run is available as
a class attribute ``TEST_SPEC``.
"""
# Newest schema version accepted; setUp fails for any newer spec file.
SCHEMA_VERSION = Version.from_string('1.5')
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True

@staticmethod
def should_run_on(run_on_spec):
    """Return True when ``run_on_spec`` is empty or when at least one of
    its runOnRequirements is satisfied."""
    if not run_on_spec:
        # Always run these tests.
        return True
    return any(is_run_on_requirement_satisfied(req) for req in run_on_spec)
def insert_initial_data(self, initial_data):
    """Drop and repopulate every collection listed in ``initial_data``.

    Writes use majority write concern. If ``documents`` is empty, an empty
    document is inserted and deleted again so the collection still exists.
    """
    majority = WriteConcern(w="majority")
    for collection_data in initial_data:
        coll_name, db_name, documents = (
            collection_data['collectionName'],
            collection_data['databaseName'],
            collection_data['documents'])
        coll = self.client.get_database(db_name).get_collection(
            coll_name, write_concern=majority)
        coll.drop()
        if not documents:
            # ensure collection exists
            placeholder = coll.insert_one({})
            coll.delete_one({'_id': placeholder.inserted_id})
            continue
        coll.insert_many(documents)
@classmethod
def setUpClass(cls):
    # super call creates internal client cls.client
    super(UnifiedSpecTestMixinV1, cls).setUpClass()

    # process file-level runOnRequirements
    if not cls.should_run_on(cls.TEST_SPEC.get('runOnRequirements', [])):
        raise unittest.SkipTest(
            '%s runOnRequirements not satisfied' % (cls.__name__,))

    # add any special-casing for skipping tests here
    on_mmapv1 = client_context.storage_engine == 'mmapv1'
    if on_mmapv1 and 'retryable-writes' in cls.TEST_SPEC['description']:
        raise unittest.SkipTest(
            "MMAPv1 does not support retryWrites=True")
def setUp(self):
    super(UnifiedSpecTestMixinV1, self).setUp()

    # process schemaVersion
    # note: we check major schema version during class generation
    # note: we do this here because we cannot run assertions in setUpClass
    spec_version = Version.from_string(self.TEST_SPEC['schemaVersion'])
    supported_version = self.SCHEMA_VERSION
    self.assertLessEqual(
        spec_version, supported_version,
        'expected schema version %s or lower, got %s' % (
            supported_version, spec_version))

    # initialize internals
    self.match_evaluator = MatchEvaluatorUtil(self)
def maybe_skip_test(self, spec):
    """Skip the test described by ``spec`` if it is known not to work."""
    description = spec['description']
    # add any special-casing for skipping tests here
    if (client_context.storage_engine == 'mmapv1' and
            'Dirty explicit session is discarded' in description):
        raise unittest.SkipTest(
            "MMAPv1 does not support retryWrites=True")
    # This skip is unrelated to the storage engine, so it is checked on its
    # own rather than chained to the mmapv1 check with ``elif``.
    if 'Client side error in command starting transaction' in description:
        raise unittest.SkipTest("Implement PYTHON-1894")
def process_error(self, exception, spec):
    """Assert that ``exception`` satisfies an expectedError document.

    ``isError`` needs no explicit check because an exception was raised.
    """
    if spec.get('isClientError'):
        # Connection errors are considered client errors.
        if isinstance(exception, ConnectionFailure):
            self.assertNotIsInstance(exception, NotPrimaryError)
        elif not isinstance(exception,
                            (InvalidOperation, ConfigurationError)):
            self.assertNotIsInstance(exception, PyMongoError)

    error_contains = spec.get('errorContains')
    if error_contains:
        # BulkWriteError messages are searched via their details document.
        searched = (exception.details
                    if isinstance(exception, BulkWriteError) else exception)
        self.assertIn(error_contains.lower(), str(searched).lower())

    error_code = spec.get('errorCode')
    if error_code:
        self.assertEqual(error_code, exception.details.get('code'))

    error_code_name = spec.get('errorCodeName')
    if error_code_name:
        self.assertEqual(error_code_name, exception.details.get('codeName'))

    required_labels = spec.get('errorLabelsContain')
    if required_labels:
        present = [label for label in required_labels
                   if exception.has_error_label(label)]
        self.assertEqual(present, required_labels)

    for label in spec.get('errorLabelsOmit') or ():
        if exception.has_error_label(label):
            self.fail("Exception '%s' unexpectedly had label '%s'" % (
                exception, label))

    expect_result = spec.get('expectResult')
    if expect_result:
        if not isinstance(exception, BulkWriteError):
            self.fail("expectResult can only be specified with %s "
                      "exceptions" % (BulkWriteError,))
        partial_result = parse_bulk_write_error_result(exception)
        self.match_evaluator.match_result(expect_result, partial_result)
def __raise_if_unsupported(self, opname, target, *target_types):
    """Fail the test unless ``target`` is an instance of ``target_types``."""
    if isinstance(target, target_types):
        return
    self.fail('Operation %s not supported for entity '
              'of type %s' % (opname, type(target)))
def __entityOperation_createChangeStream(self, target, *args, **kwargs):
    """Open a change stream on a client, database or collection entity.

    The stream is closed automatically when the test finishes.
    """
    if client_context.storage_engine == 'mmapv1':
        self.skipTest("MMAPv1 does not support change streams")
    self.__raise_if_unsupported(
        'createChangeStream', target, MongoClient, Database, Collection)
    change_stream = target.watch(*args, **kwargs)
    self.addCleanup(change_stream.close)
    return change_stream

def _clientOperation_createChangeStream(self, target, *args, **kwargs):
    # Entity-type specific entry points all share one implementation.
    return self.__entityOperation_createChangeStream(
        target, *args, **kwargs)

def _databaseOperation_createChangeStream(self, target, *args, **kwargs):
    return self.__entityOperation_createChangeStream(
        target, *args, **kwargs)

def _collectionOperation_createChangeStream(self, target, *args, **kwargs):
    return self.__entityOperation_createChangeStream(
        target, *args, **kwargs)
def _databaseOperation_runCommand(self, target, **kwargs):
    """Run ``command`` on a Database entity with the command name first."""
    self.__raise_if_unsupported('runCommand', target, Database)
    # Ensure the first key is the command name.
    command_name = kwargs.pop('command_name')
    ordered = SON([(command_name, 1)])
    ordered.update(kwargs['command'])
    kwargs['command'] = ordered
    return target.command(**kwargs)
def _databaseOperation_listCollections(self, target, *args, **kwargs):
    """Return the database's collections as a list.

    A ``batch_size`` argument is passed through as ``cursor.batchSize``.
    """
    if 'batch_size' in kwargs:
        batch_size = kwargs.pop('batch_size')
        kwargs['cursor'] = {'batchSize': batch_size}
    return [info for info in target.list_collections(*args, **kwargs)]
def __entityOperation_aggregate(self, target, *args, **kwargs):
    """Run an aggregation on a Database or Collection entity and return
    all resulting documents as a list."""
    self.__raise_if_unsupported('aggregate', target, Database, Collection)
    cursor = target.aggregate(*args, **kwargs)
    return list(cursor)

def _databaseOperation_aggregate(self, target, *args, **kwargs):
    return self.__entityOperation_aggregate(target, *args, **kwargs)

def _collectionOperation_aggregate(self, target, *args, **kwargs):
    return self.__entityOperation_aggregate(target, *args, **kwargs)
def _collectionOperation_find(self, target, *args, **kwargs):
    """Run ``find`` on a Collection entity and return all documents."""
    self.__raise_if_unsupported('find', target, Collection)
    return [doc for doc in target.find(*args, **kwargs)]
def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
    """Create a find cursor whose server-side cursor is opened immediately.

    The cursor is closed automatically when the test finishes.
    """
    # Report the operation the spec actually used, not 'find'.
    self.__raise_if_unsupported('createFindCursor', target, Collection)
    if 'filter' not in kwargs:
        self.fail('createFindCursor requires a "filter" argument')
    cursor = NonLazyCursor(target.find(*args, **kwargs))
    self.addCleanup(cursor.close)
    return cursor
def _collectionOperation_listIndexes(self, target, *args, **kwargs):
    """Return the collection's indexes as a list.

    The cursor is materialized, as the find/aggregate/listCollections
    handlers do, so that expectResult (a list) can be matched against it.
    A raw CommandCursor is not a sequence.
    """
    if 'batch_size' in kwargs:
        self.skipTest('PyMongo does not support batch_size for '
                      'list_indexes')
    return list(target.list_indexes(*args, **kwargs))
def __check_transaction_operation(self, opname, target):
    """Skip on MMAPv1 and fail unless ``target`` is a ClientSession."""
    if client_context.storage_engine == 'mmapv1':
        self.skipTest('MMAPv1 does not support document-level locking')
    self.__raise_if_unsupported(opname, target, ClientSession)

def _sessionOperation_withTransaction(self, target, *args, **kwargs):
    self.__check_transaction_operation('withTransaction', target)
    return target.with_transaction(*args, **kwargs)

def _sessionOperation_startTransaction(self, target, *args, **kwargs):
    self.__check_transaction_operation('startTransaction', target)
    return target.start_transaction(*args, **kwargs)
def _changeStreamOperation_iterateUntilDocumentOrError(self, target,
                                                       *args, **kwargs):
    """Return the next change document from a ChangeStream entity."""
    self.__raise_if_unsupported(
        'iterateUntilDocumentOrError', target, ChangeStream)
    document = next(target)
    return document

def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs):
    """Return the next document from a NonLazyCursor entity."""
    self.__raise_if_unsupported(
        'iterateUntilDocumentOrError', target, NonLazyCursor)
    document = next(target)
    return document

def _cursor_close(self, target, *args, **kwargs):
    """Close a NonLazyCursor entity."""
    self.__raise_if_unsupported('close', target, NonLazyCursor)
    outcome = target.close()
    return outcome
def run_entity_operation(self, spec):
    """Execute one operation against an entity and check its outcome.

    The handler is looked up by entity type (e.g.
    ``_collectionOperation_<name>``). If no handler exists, the snake_case
    method of the same name is called on the entity itself. The outcome is
    then checked against ``expectError``/``expectResult`` and may be saved
    via ``saveResultAsEntity``.
    """
    target = self.entity_map[spec['object']]
    opname = spec['name']
    opargs = spec.get('arguments')
    expect_error = spec.get('expectError')
    save_as_entity = spec.get('saveResultAsEntity')
    # expectResult may legitimately be falsy (0, [], {}, False), so its
    # presence -- not its truthiness -- decides whether it is checked.
    has_expect_result = 'expectResult' in spec
    expect_result = spec.get('expectResult')
    ignore = spec.get('ignoreResultAndError')
    if ignore and (expect_error or save_as_entity or has_expect_result):
        raise ValueError(
            'ignoreResultAndError is incompatible with saveResultAsEntity'
            ', expectError, and expectResult')

    if opargs:
        arguments = parse_spec_options(copy.deepcopy(opargs))
        prepare_spec_arguments(spec, arguments, camel_to_snake(opname),
                               self.entity_map, self.run_operations)
    else:
        arguments = tuple()

    if isinstance(target, MongoClient):
        method_name = '_clientOperation_%s' % (opname,)
    elif isinstance(target, Database):
        method_name = '_databaseOperation_%s' % (opname,)
    elif isinstance(target, Collection):
        method_name = '_collectionOperation_%s' % (opname,)
    elif isinstance(target, ChangeStream):
        method_name = '_changeStreamOperation_%s' % (opname,)
    elif isinstance(target, NonLazyCursor):
        method_name = '_cursor_%s' % (opname,)
    elif isinstance(target, ClientSession):
        method_name = '_sessionOperation_%s' % (opname,)
    elif isinstance(target, GridFSBucket):
        raise NotImplementedError
    else:
        method_name = 'doesNotExist'

    try:
        method = getattr(self, method_name)
    except AttributeError:
        # Fall back to the driver method on the entity itself.
        try:
            cmd = getattr(target, camel_to_snake(opname))
        except AttributeError:
            self.fail('Unsupported operation %s on entity %s' % (
                opname, target))
    else:
        cmd = functools.partial(method, target)

    try:
        result = cmd(**dict(arguments))
    except Exception as exc:
        # Ignore all operation errors but to avoid masking bugs don't
        # ignore things like TypeError and ValueError.
        if ignore and isinstance(exc, (PyMongoError,)):
            return
        if expect_error:
            return self.process_error(exc, expect_error)
        raise
    else:
        if expect_error:
            self.fail('Expected error %s but "%s" succeeded: %s' % (
                expect_error, opname, result))

    if has_expect_result:
        actual = coerce_result(opname, result)
        self.match_evaluator.match_result(expect_result, actual)

    if save_as_entity:
        self.entity_map[save_as_entity] = result
def __set_fail_point(self, client, command_args):
    """Enable a failCommand fail point through ``client``.

    A cleanup is registered that sets the same fail point to mode 'off'.
    """
    if not client_context.test_commands_enabled:
        self.skipTest('Test commands must be enabled')

    fail_point_cmd = SON([('configureFailPoint', 'failCommand')])
    fail_point_cmd.update(command_args)
    client.admin.command(fail_point_cmd)
    fail_point_name = fail_point_cmd['configureFailPoint']
    self.addCleanup(
        client.admin.command,
        'configureFailPoint', fail_point_name, mode='off')
def _testOperation_failPoint(self, spec):
    """Configure ``spec['failPoint']`` through the named client entity."""
    client = self.entity_map[spec['client']]
    self.__set_fail_point(client=client, command_args=spec['failPoint'])
def _testOperation_targetedFailPoint(self, spec):
    """Configure a fail point on the server the named session is pinned to."""
    session_name = spec['session']
    pinned_address = self.entity_map[session_name]._pinned_address
    if not pinned_address:
        self.fail("Cannot use targetedFailPoint operation with unpinned "
                  "session %s" % (session_name,))

    # Talk directly to the pinned server.
    direct_client = single_client('%s:%s' % pinned_address)
    self.addCleanup(direct_client.close)
    self.__set_fail_point(
        client=direct_client, command_args=spec['failPoint'])
def _testOperation_assertSessionTransactionState(self, spec):
    """Assert that the named session's transaction is in ``spec['state']``.

    The expected state is found by upper-casing the spec value and looking
    it up as an attribute of ``_TxnState``.
    """
    session = self.entity_map[spec['session']]
    state_attr = spec['state'].upper()
    self.assertEqual(getattr(_TxnState, state_attr),
                     session._transaction.state)
def _testOperation_assertSessionPinned(self, spec):
    """Assert that the named session's transaction has a pinned address."""
    transaction = self.entity_map[spec['session']]._transaction
    self.assertIsNotNone(transaction.pinned_address)
def _testOperation_assertSessionUnpinned(self, spec):
    """Assert that the named session is not pinned.

    Both the session-level ``_pinned_address`` and the transaction's
    ``pinned_address`` must be ``None``; they are checked in that order.
    """
    session = self.entity_map[spec['session']]
    self.assertIsNone(session._pinned_address)
    transaction = session._transaction
    self.assertIsNone(transaction.pinned_address)
def __get_last_two_command_lsids(self, listener):
    """Return the ``lsid`` of the two most recent CommandStartedEvents.

    ``listener.events`` is scanned newest-first, and the scan stops as soon
    as two CommandStartedEvents have been found. The result is a 2-tuple
    ordered newest first. Fails the test if fewer than two such events
    were recorded.
    """
    cmd_started_events = []
    for event in reversed(listener.events):
        if isinstance(event, CommandStartedEvent):
            cmd_started_events.append(event)
            # Only the two newest events are compared, so there is no
            # need to walk the rest of the (possibly long) history.
            if len(cmd_started_events) == 2:
                break
    if len(cmd_started_events) < 2:
        self.fail('Needed 2 CommandStartedEvents to compare lsids, '
                  'got %s' % (len(cmd_started_events),))
    return tuple(e.command['lsid'] for e in cmd_started_events)
def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
    """Assert that the client's last two started commands used different
    lsids."""
    listener = self.entity_map.get_listener_for_client(spec['client'])
    newest, previous = self.__get_last_two_command_lsids(listener)
    self.assertNotEqual(newest, previous)
def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
    """Assert that the client's last two started commands used the same
    lsid."""
    listener = self.entity_map.get_listener_for_client(spec['client'])
    newest, previous = self.__get_last_two_command_lsids(listener)
    self.assertEqual(newest, previous)
def _testOperation_assertSessionDirty(self, spec):
    """Assert that the named session's server session is marked dirty."""
    server_session = self.entity_map[spec['session']]._server_session
    self.assertTrue(server_session.dirty)
def _testOperation_assertSessionNotDirty(self, spec):
    """Assert that the named session's server session is not marked dirty.

    Like the other ``_testOperation_*`` assertion handlers, this returns
    nothing; it only asserts.
    """
    session = self.entity_map[spec['session']]
    self.assertFalse(session._server_session.dirty)
def _testOperation_assertCollectionExists(self, spec):
    """Assert that ``spec['collectionName']`` exists in
    ``spec['databaseName']``.

    The check uses the test runner's own ``self.client``, not an entity.
    """
    target_db = self.client.get_database(spec['databaseName'])
    existing = list(target_db.list_collection_names())
    self.assertIn(spec['collectionName'], existing)
def _testOperation_assertCollectionNotExists(self, spec):
    """Assert that ``spec['collectionName']`` does not exist in
    ``spec['databaseName']``.

    The check uses the test runner's own ``self.client``, not an entity.
    """
    target_db = self.client.get_database(spec['databaseName'])
    existing = list(target_db.list_collection_names())
    self.assertNotIn(spec['collectionName'], existing)
def _testOperation_assertIndexExists(self, spec):
    """Assert that an index named ``spec['indexName']`` exists on the
    collection ``spec['databaseName']``.``spec['collectionName']``."""
    database = self.client[spec['databaseName']]
    collection = database[spec['collectionName']]
    existing_names = [index_doc['name']
                      for index_doc in collection.list_indexes()]
    self.assertIn(spec['indexName'], existing_names)
def _testOperation_assertIndexNotExists(self, spec):
    """Assert that no index named ``spec['indexName']`` exists on the
    collection ``spec['databaseName']``.``spec['collectionName']``.

    All index names are collected first and checked with a single
    ``assertNotIn``. This matches ``_testOperation_assertIndexExists`` and
    makes a failure report every index that was found.
    """
    collection = self.client[spec['databaseName']][spec['collectionName']]
    index_names = [idx['name'] for idx in collection.list_indexes()]
    self.assertNotIn(spec['indexName'], index_names)
def _testOperation_assertNumberConnectionsCheckedOut(self, spec):
    """Assert that the named client's pool reports ``spec['connections']``
    active sockets."""
    pool = get_pool(self.entity_map[spec['client']])
    expected_count = spec['connections']
    self.assertEqual(expected_count, pool.active_sockets)
def run_special_operation(self, spec):
    """Dispatch a ``testRunner`` operation to ``self._testOperation_<name>``.

    The handler is called with ``spec['arguments']`` and its return value
    is discarded. An operation name with no matching handler fails the
    test.
    """
    opname = spec['name']
    missing = object()
    handler = getattr(self, '_testOperation_%s' % (opname,), missing)
    if handler is missing:
        self.fail('Unsupported special test operation %s' % (opname,))
        return
    handler(spec['arguments'])
def run_operations(self, spec):
    """Execute each operation in order.

    Operations whose ``object`` is ``'testRunner'`` go to
    ``run_special_operation``; all others go to ``run_entity_operation``.
    """
    for operation in spec:
        if operation['object'] == 'testRunner':
            self.run_special_operation(operation)
        else:
            self.run_entity_operation(operation)
def check_events(self, spec):
    """Compare observed events against an ``expectEvents`` list.

    For each entry, the listener for ``client`` is asked for events of
    ``eventType``, which defaults to ``'command'`` and may also be
    ``'cmap'``. How the events are compared:

    - An empty expected list requires that no events were observed.
    - Otherwise, at least as many events as expected must have been
      observed.
    - Each expected event is matched, in order, against the observed event
      at the same position.
    - Extra trailing observed events are allowed.
    """
    for event_spec in spec:
        client_name = event_spec['client']
        events = event_spec['events']
        # Valid types: 'command', 'cmap'. Use a unittest assertion instead
        # of a bare ``assert`` so the check is not stripped by ``python -O``.
        event_type = event_spec.get('eventType', 'command')
        self.assertIn(event_type, ('command', 'cmap'))
        listener = self.entity_map.get_listener_for_client(client_name)
        actual_events = listener.get_events(event_type)
        if len(events) == 0:
            self.assertEqual(actual_events, [])
            continue
        if len(events) > len(actual_events):
            self.fail('Expected to see %s events, got %s' % (
                len(events), len(actual_events)))
        # The length check above guarantees zip covers every expected event.
        for expected_event, actual_event in zip(events, actual_events):
            self.match_evaluator.match_event(
                event_type, expected_event, actual_event)
def verify_outcome(self, spec):
    """Assert that each collection's final contents match ``outcome``.

    For every ``collectionData`` entry, the collection is read from the
    primary with ``local`` read concern and sorted by ``_id``. It is then
    compared with the expected documents, also sorted by ``_id``. An empty
    ``documents`` list asserts that the collection is empty.
    """
    for collection_data in spec:
        coll_name = collection_data['collectionName']
        db_name = collection_data['databaseName']
        expected_documents = collection_data['documents']
        coll = self.client.get_database(db_name).get_collection(
            coll_name,
            read_preference=ReadPreference.PRIMARY,
            read_concern=ReadConcern(level='local'))
        # Always compare, even when no documents are expected: an empty
        # list means the collection must end up empty, and skipping the
        # check would let leftover documents go unnoticed.
        sorted_expected_documents = sorted(
            expected_documents, key=lambda doc: doc['_id'])
        actual_documents = list(coll.find({}, sort=[('_id', ASCENDING)]))
        self.assertListEqual(sorted_expected_documents, actual_documents)
def run_scenario(self, spec):
    """Run a single test case from the unified-format file.

    The steps run in this order:

    1. The manual skip hook.
    2. The test-level ``runOnRequirements`` and ``skipReason`` checks.
    3. Entity creation and initial data, both taken from the file-level
       ``TEST_SPEC``.
    4. The test's operations.
    5. Verification of ``expectEvents`` (which must be non-empty when
       present) and of ``outcome``.
    """
    self.maybe_skip_test(spec)

    if not self.should_run_on(spec.get('runOnRequirements', [])):
        raise unittest.SkipTest('runOnRequirements not satisfied')

    skip_reason = spec.get('skipReason')
    if skip_reason is not None:
        raise unittest.SkipTest('%s' % (skip_reason,))

    self.entity_map = EntityMapUtil(self)
    file_spec = self.TEST_SPEC
    self.entity_map.create_entities_from_spec(
        file_spec.get('createEntities', []))
    self.insert_initial_data(file_spec.get('initialData', []))

    self.run_operations(spec['operations'])

    if 'expectEvents' in spec:
        expected_events = spec['expectEvents']
        self.assertTrue(expected_events, 'expectEvents must be non-empty')
        self.check_events(expected_events)

    self.verify_outcome(spec.get('outcome', []))
class UnifiedSpecTestMeta(type):
    """Metaclass for generating test classes.

    Each entry of ``cls.TEST_SPEC['tests']`` becomes a ``test_<description>``
    method. The method name is the description stripped of surrounding
    dots and spaces, with spaces and dots replaced by underscores. The
    method runs ``self.run_scenario`` on a deep copy of that test's spec.
    Descriptions matching any regex in ``cls.EXPECTED_FAILURES`` are
    wrapped with ``unittest.expectedFailure``.
    """
    def __init__(cls, *args, **kwargs):
        super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs)

        def make_test_method(scenario):
            def test_case(self):
                self.run_scenario(scenario)
            return test_case

        for scenario_spec in cls.TEST_SPEC['tests']:
            desc = scenario_spec['description']
            slug = desc.strip('. ').replace(' ', '_').replace('.', '_')
            method_name = 'test_%s' % (slug,)
            method = make_test_method(copy.deepcopy(scenario_spec))
            method.__name__ = str(method_name)
            if any(re.search(pattern, desc)
                   for pattern in cls.EXPECTED_FAILURES):
                method = unittest.expectedFailure(method)
            setattr(cls, method_name, method)
# Every supported runner mixin. Each one exposes a SCHEMA_VERSION whose
# first element is the schema major version it implements.
_ALL_MIXIN_CLASSES = [
    UnifiedSpecTestMixinV1,
    # add mixin classes for new schema major versions here
]

# Lookup table from schema major version to the mixin class that runs it.
_SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = dict(
    (mixin_cls.SCHEMA_VERSION[0], mixin_cls)
    for mixin_cls in _ALL_MIXIN_CLASSES)
def generate_test_classes(test_path, module=__name__, class_name_prefix='',
                          expected_failures=None,
                          bypass_test_generation_errors=False,
                          **kwargs):
    """Method for generating test classes. Returns a dictionary where keys are
    the names of test classes and values are the test class objects.

    Every file found by walking ``test_path`` is parsed as extended JSON
    (with ``tz_aware=False``). Each file becomes one test class, named from
    ``class_name_prefix``, its directory name and its file name.

    :Parameters:
      - `test_path`: directory that is walked recursively.
      - `module`: value used for each generated class's ``__module__``.
      - `class_name_prefix`: prefix inserted into generated class names.
      - `expected_failures`: regex patterns; test descriptions matching any
        of them are marked as expected failures. Defaults to no patterns.
      - `bypass_test_generation_errors`: if True, a file whose class
        creation raises (for example an unsupported ``schemaVersion``) is
        skipped instead of propagating the error. Errors while reading or
        parsing the file itself are not bypassed.
      - `**kwargs`: extra attributes set on every generated class.
    """
    if expected_failures is None:
        # Avoid a shared mutable default: this list becomes the
        # EXPECTED_FAILURES attribute of every generated class.
        expected_failures = []
    test_klasses = {}

    def test_base_class_factory(test_spec):
        """Utility that creates the base class to use for test generation.
        This is needed to ensure that cls.TEST_SPEC is appropriately set when
        the metaclass __init__ is invoked."""
        class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)):
            TEST_SPEC = test_spec
            EXPECTED_FAILURES = expected_failures
        return SpecTestBase

    for dirpath, _, filenames in os.walk(test_path):
        dirname = os.path.basename(dirpath)
        for filename in filenames:
            fpath = os.path.join(dirpath, filename)
            with open(fpath) as scenario_stream:
                # Use tz_aware=False to match how CodecOptions decodes
                # dates.
                opts = json_util.JSONOptions(tz_aware=False)
                scenario_def = json_util.loads(
                    scenario_stream.read(), json_options=opts)
            test_type = os.path.splitext(filename)[0]
            snake_class_name = 'Test%s_%s_%s' % (
                class_name_prefix, dirname.replace('-', '_'),
                test_type.replace('-', '_').replace('.', '_'))
            class_name = snake_to_camel(snake_class_name)
            try:
                schema_version = Version.from_string(
                    scenario_def['schemaVersion'])
                mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(
                    schema_version[0])
                if mixin_class is None:
                    raise ValueError(
                        "test file '%s' has unsupported schemaVersion '%s'" % (
                            fpath, schema_version))
                module_dict = {'__module__': module}
                module_dict.update(kwargs)
                test_klasses[class_name] = type(
                    class_name,
                    (mixin_class, test_base_class_factory(scenario_def),),
                    module_dict)
            except Exception:
                if bypass_test_generation_errors:
                    continue
                raise
    return test_klasses