Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pymongo / test / test_change_stream.py
Size: Mime:
# Copyright 2017 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the change_stream module."""

import random
import os
import re
import sys
import string
import threading
import time
import uuid

from contextlib import contextmanager
from itertools import product

sys.path[0:0] = ['']

from bson import ObjectId, SON, Timestamp, encode, json_util
from bson.binary import (ALL_UUID_REPRESENTATIONS,
                         Binary,
                         STANDARD,
                         PYTHON_LEGACY)
from bson.py3compat import iteritems
from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument

from pymongo import MongoClient
from pymongo.command_cursor import CommandCursor
from pymongo.errors import (InvalidOperation, OperationFailure,
                            ServerSelectionTimeoutError)
from pymongo.message import _CursorAddress
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

from test import client_context, unittest, IntegrationTest
from test.unified_format import generate_test_classes
from test.utils import (
    EventListener, AllowListEventListener, rs_or_single_client, wait_until)


class TestChangeStreamBase(IntegrationTest):
    RUN_ON_LOAD_BALANCER = True

    def change_stream_with_client(self, client, *args, **kwargs):
        """Create a change stream using the given client and return it."""
        raise NotImplementedError

    def change_stream(self, *args, **kwargs):
        """Create a change stream using the default client and return it."""
        return self.change_stream_with_client(self.client, *args, **kwargs)

    def client_with_listener(self, *commands):
        """Return a client with a AllowListEventListener."""
        listener = AllowListEventListener(*commands)
        client = rs_or_single_client(event_listeners=[listener])
        self.addCleanup(client.close)
        return client, listener

    def watched_collection(self, *args, **kwargs):
        """Return a collection that is watched by self.change_stream()."""
        # Construct a unique collection for each test.
        collname = '.'.join(self.id().rsplit('.', 2)[1:])
        return self.db.get_collection(collname, *args, **kwargs)

    def generate_invalidate_event(self, change_stream):
        """Cause a change stream invalidate event."""
        raise NotImplementedError

    def generate_unique_collnames(self, numcolls):
        """Generate numcolls collection names unique to a test."""
        collnames = []
        for idx in range(1, numcolls + 1):
            collnames.append(self.id() + '_' + str(idx))
        return collnames

    def get_resume_token(self, invalidate=False):
        """Get a resume token to use for starting a change stream."""
        # Ensure targeted collection exists before starting.
        coll = self.watched_collection(write_concern=WriteConcern('majority'))
        coll.insert_one({})

        if invalidate:
            with self.change_stream(
                    [{'$match': {'operationType': 'invalidate'}}]) as cs:
                if isinstance(cs._target, MongoClient):
                    self.skipTest(
                        "cluster-level change streams cannot be invalidated")
                self.generate_invalidate_event(cs)
                return cs.next()['_id']
        else:
            with self.change_stream() as cs:
                coll.insert_one({'data': 1})
                return cs.next()['_id']

    def get_start_at_operation_time(self):
        """Get an operationTime. Advances the operation clock beyond the most
        recently returned timestamp."""
        optime = self.client.admin.command("ping")["operationTime"]
        return Timestamp(optime.time, optime.inc + 1)

    def insert_one_and_check(self, change_stream, doc):
        """Insert a document and check that it shows up in the change stream."""
        raise NotImplementedError

    def kill_change_stream_cursor(self, change_stream):
        """Cause a cursor not found error on the next getMore."""
        cursor = change_stream._cursor
        address = _CursorAddress(cursor.address, cursor._CommandCursor__ns)
        client = self.watched_collection().database.client
        client._close_cursor_now(cursor.cursor_id, address)


class APITestsMixin(object):
    def test_watch(self):
        with self.change_stream(
                [{'$project': {'foo': 0}}], full_document='updateLookup',
                max_await_time_ms=1000, batch_size=100) as change_stream:
            self.assertEqual([{'$project': {'foo': 0}}],
                             change_stream._pipeline)
            self.assertEqual('updateLookup', change_stream._full_document)
            self.assertEqual(1000, change_stream._max_await_time_ms)
            self.assertEqual(100, change_stream._batch_size)
            self.assertIsInstance(change_stream._cursor, CommandCursor)
            self.assertEqual(
                1000, change_stream._cursor._CommandCursor__max_await_time_ms)
            self.watched_collection(
                write_concern=WriteConcern("majority")).insert_one({})
            _ = change_stream.next()
            resume_token = change_stream.resume_token
        with self.assertRaises(TypeError):
            self.change_stream(pipeline={})
        with self.assertRaises(TypeError):
            self.change_stream(full_document={})
        # No Error.
        with self.change_stream(resume_after=resume_token):
            pass

    def test_try_next(self):
        # ChangeStreams only read majority committed data so use w:majority.
        coll = self.watched_collection().with_options(
            write_concern=WriteConcern("majority"))
        coll.drop()
        coll.insert_one({})
        self.addCleanup(coll.drop)
        with self.change_stream(max_await_time_ms=250) as stream:
            self.assertIsNone(stream.try_next())        # No changes initially.
            coll.insert_one({})                         # Generate a change.
            # On sharded clusters, even majority-committed changes only show
            # up once an event that sorts after it shows up on the other
            # shard. So, we wait on try_next to eventually return changes.
            wait_until(lambda: stream.try_next() is not None,
                       "get change from try_next")

    def test_try_next_runs_one_getmore(self):
        listener = EventListener()
        client = rs_or_single_client(event_listeners=[listener])
        # Connect to the cluster.
        client.admin.command('ping')
        listener.results.clear()
        # ChangeStreams only read majority committed data so use w:majority.
        coll = self.watched_collection().with_options(
            write_concern=WriteConcern("majority"))
        coll.drop()
        # Create the watched collection before starting the change stream to
        # skip any "create" events.
        coll.insert_one({'_id': 1})
        self.addCleanup(coll.drop)
        with self.change_stream_with_client(
                client, max_await_time_ms=250) as stream:
            self.assertEqual(listener.started_command_names(), ["aggregate"])
            listener.results.clear()

            # Confirm that only a single getMore is run even when no documents
            # are returned.
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(), ["getMore"])
            listener.results.clear()
            self.assertIsNone(stream.try_next())
            self.assertEqual(listener.started_command_names(), ["getMore"])
            listener.results.clear()

            # Get at least one change before resuming.
            coll.insert_one({'_id': 2})
            wait_until(lambda: stream.try_next() is not None,
                       "get change from try_next")
            listener.results.clear()

            # Cause the next request to initiate the resume process.
            self.kill_change_stream_cursor(stream)
            listener.results.clear()

            # The sequence should be:
            # - getMore, fail
            # - resume with aggregate command
            # - no results, return immediately without another getMore
            self.assertIsNone(stream.try_next())
            self.assertEqual(
                listener.started_command_names(), ["getMore", "aggregate"])
            listener.results.clear()

            # Stream still works after a resume.
            coll.insert_one({'_id': 3})
            wait_until(lambda: stream.try_next() is not None,
                       "get change from try_next")
            self.assertEqual(set(listener.started_command_names()),
                             set(["getMore"]))
            self.assertIsNone(stream.try_next())

    def test_batch_size_is_honored(self):
        listener = EventListener()
        client = rs_or_single_client(event_listeners=[listener])
        # Connect to the cluster.
        client.admin.command('ping')
        listener.results.clear()
        # ChangeStreams only read majority committed data so use w:majority.
        coll = self.watched_collection().with_options(
            write_concern=WriteConcern("majority"))
        coll.drop()
        # Create the watched collection before starting the change stream to
        # skip any "create" events.
        coll.insert_one({'_id': 1})
        self.addCleanup(coll.drop)
        # Expected batchSize.
        expected = {'batchSize': 23}
        with self.change_stream_with_client(
                client, max_await_time_ms=250, batch_size=23) as stream:
            # Confirm that batchSize is honored for initial batch.
            cmd = listener.results['started'][0].command
            self.assertEqual(cmd['cursor'], expected)
            listener.results.clear()
            # Confirm that batchSize is honored by getMores.
            self.assertIsNone(stream.try_next())
            cmd = listener.results['started'][0].command
            key = next(iter(expected))
            self.assertEqual(expected[key], cmd[key])

    # $changeStream.startAtOperationTime was added in 4.0.0.
    @client_context.require_version_min(4, 0, 0)
    def test_start_at_operation_time(self):
        optime = self.get_start_at_operation_time()

        coll = self.watched_collection(
            write_concern=WriteConcern("majority"))
        ndocs = 3
        coll.insert_many([{"data": i} for i in range(ndocs)])

        with self.change_stream(start_at_operation_time=optime) as cs:
            for i in range(ndocs):
                cs.next()

    def _test_full_pipeline(self, expected_cs_stage):
        client, listener = self.client_with_listener("aggregate")
        results = listener.results
        with self.change_stream_with_client(
                client, [{'$project': {'foo': 0}}]) as _:
            pass

        self.assertEqual(1, len(results['started']))
        command = results['started'][0]
        self.assertEqual('aggregate', command.command_name)
        self.assertEqual([
            {'$changeStream': expected_cs_stage},
            {'$project': {'foo': 0}}],
            command.command['pipeline'])

    def test_full_pipeline(self):
        """$changeStream must be the first stage in a change stream pipeline
        sent to the server.
        """
        self._test_full_pipeline({})

    def test_iteration(self):
        with self.change_stream(batch_size=2) as change_stream:
            num_inserted = 10
            self.watched_collection().insert_many(
                [{} for _ in range(num_inserted)])
            inserts_received = 0
            for change in change_stream:
                self.assertEqual(change['operationType'], 'insert')
                inserts_received += 1
                if inserts_received == num_inserted:
                    break
            self._test_invalidate_stops_iteration(change_stream)

    def _test_next_blocks(self, change_stream):
        inserted_doc = {'_id': ObjectId()}
        changes = []
        t = threading.Thread(
            target=lambda: changes.append(change_stream.next()))
        t.start()
        # Sleep for a bit to prove that the call to next() blocks.
        time.sleep(1)
        self.assertTrue(t.is_alive())
        self.assertFalse(changes)
        self.watched_collection().insert_one(inserted_doc)
        # Join with large timeout to give the server time to return the change,
        # in particular for shard clusters.
        t.join(30)
        self.assertFalse(t.is_alive())
        self.assertEqual(1, len(changes))
        self.assertEqual(changes[0]['operationType'], 'insert')
        self.assertEqual(changes[0]['fullDocument'], inserted_doc)

    def test_next_blocks(self):
        """Test that next blocks until a change is readable"""
        # Use a short await time to speed up the test.
        with self.change_stream(max_await_time_ms=250) as change_stream:
            self._test_next_blocks(change_stream)

    def test_aggregate_cursor_blocks(self):
        """Test that an aggregate cursor blocks until a change is readable."""
        with self.watched_collection().aggregate(
                [{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream:
            self._test_next_blocks(change_stream)

    def test_concurrent_close(self):
        """Ensure a ChangeStream can be closed from another thread."""
        # Use a short await time to speed up the test.
        with self.change_stream(max_await_time_ms=250) as change_stream:
            def iterate_cursor():
                for _ in change_stream:
                    pass
            t = threading.Thread(target=iterate_cursor)
            t.start()
            self.watched_collection().insert_one({})
            time.sleep(1)
            change_stream.close()
            t.join(3)
            self.assertFalse(t.is_alive())

    def test_unknown_full_document(self):
        """Must rely on the server to raise an error on unknown fullDocument.
        """
        try:
            with self.change_stream(full_document='notValidatedByPyMongo'):
                pass
        except OperationFailure:
            pass

    def test_change_operations(self):
        """Test each operation type."""
        expected_ns = {'db': self.watched_collection().database.name,
                       'coll': self.watched_collection().name}
        with self.change_stream() as change_stream:
            # Insert.
            inserted_doc = {'_id': ObjectId(), 'foo': 'bar'}
            self.watched_collection().insert_one(inserted_doc)
            change = change_stream.next()
            self.assertTrue(change['_id'])
            self.assertEqual(change['operationType'], 'insert')
            self.assertEqual(change['ns'], expected_ns)
            self.assertEqual(change['fullDocument'], inserted_doc)
            # Update.
            update_spec = {'$set': {'new': 1}, '$unset': {'foo': 1}}
            self.watched_collection().update_one(inserted_doc, update_spec)
            change = change_stream.next()
            self.assertTrue(change['_id'])
            self.assertEqual(change['operationType'], 'update')
            self.assertEqual(change['ns'], expected_ns)
            self.assertNotIn('fullDocument', change)

            expected_update_description = {
                'updatedFields': {'new': 1},
                'removedFields': ['foo']}
            if client_context.version.at_least(4, 5, 0):
                expected_update_description['truncatedArrays'] = []
            self.assertEqual(expected_update_description,
                             change['updateDescription'])
            # Replace.
            self.watched_collection().replace_one({'new': 1}, {'foo': 'bar'})
            change = change_stream.next()
            self.assertTrue(change['_id'])
            self.assertEqual(change['operationType'], 'replace')
            self.assertEqual(change['ns'], expected_ns)
            self.assertEqual(change['fullDocument'], inserted_doc)
            # Delete.
            self.watched_collection().delete_one({'foo': 'bar'})
            change = change_stream.next()
            self.assertTrue(change['_id'])
            self.assertEqual(change['operationType'], 'delete')
            self.assertEqual(change['ns'], expected_ns)
            self.assertNotIn('fullDocument', change)
            # Invalidate.
            self._test_get_invalidate_event(change_stream)

    @client_context.require_version_min(4, 1, 1)
    def test_start_after(self):
        resume_token = self.get_resume_token(invalidate=True)

        # resume_after cannot resume after invalidate.
        with self.assertRaises(OperationFailure):
            self.change_stream(resume_after=resume_token)

        # start_after can resume after invalidate.
        with self.change_stream(start_after=resume_token) as change_stream:
            self.watched_collection().insert_one({'_id': 2})
            change = change_stream.next()
            self.assertEqual(change['operationType'], 'insert')
            self.assertEqual(change['fullDocument'], {'_id': 2})

    @client_context.require_version_min(4, 1, 1)
    def test_start_after_resume_process_with_changes(self):
        resume_token = self.get_resume_token(invalidate=True)

        with self.change_stream(start_after=resume_token,
                                max_await_time_ms=250) as change_stream:
            self.watched_collection().insert_one({'_id': 2})
            change = change_stream.next()
            self.assertEqual(change['operationType'], 'insert')
            self.assertEqual(change['fullDocument'], {'_id': 2})

            self.assertIsNone(change_stream.try_next())
            self.kill_change_stream_cursor(change_stream)

            self.watched_collection().insert_one({'_id': 3})
            change = change_stream.next()
            self.assertEqual(change['operationType'], 'insert')
            self.assertEqual(change['fullDocument'], {'_id': 3})

    @client_context.require_no_mongos  # Remove after SERVER-41196
    @client_context.require_version_min(4, 1, 1)
    def test_start_after_resume_process_without_changes(self):
        resume_token = self.get_resume_token(invalidate=True)

        with self.change_stream(start_after=resume_token,
                                max_await_time_ms=250) as change_stream:
            self.assertIsNone(change_stream.try_next())
            self.kill_change_stream_cursor(change_stream)

            self.watched_collection().insert_one({'_id': 2})
            change = change_stream.next()
            self.assertEqual(change['operationType'], 'insert')
            self.assertEqual(change['fullDocument'], {'_id': 2})


class ProseSpecTestsMixin(object):
    def _client_with_listener(self, *commands):
        listener = AllowListEventListener(*commands)
        client = rs_or_single_client(event_listeners=[listener])
        self.addCleanup(client.close)
        return client, listener

    def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3):
        self.watched_collection().insert_many(
            [{"data": k} for k in range(batch_size)])
        for _ in range(batch_size):
            change = next(change_stream)
        return change

    def _get_expected_resume_token_legacy(self, stream,
                                          listener, previous_change=None):
        """Predicts what the resume token should currently be for server
        versions that don't support postBatchResumeToken. Assumes the stream
        has never returned any changes if previous_change is None."""
        if previous_change is None:
            agg_cmd = listener.results['started'][0]
            stage = agg_cmd.command["pipeline"][0]["$changeStream"]
            return stage.get("resumeAfter") or stage.get("startAfter")

        return previous_change['_id']

    def _get_expected_resume_token(self, stream, listener,
                                   previous_change=None):
        """Predicts what the resume token should currently be for server
        versions that support postBatchResumeToken. Assumes the stream has
        never returned any changes if previous_change is None. Assumes
        listener is a AllowListEventListener that listens for aggregate and
        getMore commands."""
        if previous_change is None or stream._cursor._has_next():
            token = self._get_expected_resume_token_legacy(
                stream, listener, previous_change)
            if token is not None:
                return token

        response = listener.results['succeeded'][-1].reply
        return response['cursor']['postBatchResumeToken']

    def _test_raises_error_on_missing_id(self, expected_exception):
        """ChangeStream will raise an exception if the server response is
        missing the resume token.
        """
        with self.change_stream([{'$project': {'_id': 0}}]) as change_stream:
            self.watched_collection().insert_one({})
            with self.assertRaises(expected_exception):
                next(change_stream)
            # The cursor should now be closed.
            with self.assertRaises(StopIteration):
                next(change_stream)

    def _test_update_resume_token(self, expected_rt_getter):
        """ChangeStream must continuously track the last seen resumeToken."""
        client, listener = self._client_with_listener("aggregate", "getMore")
        coll = self.watched_collection(write_concern=WriteConcern('majority'))
        with self.change_stream_with_client(client) as change_stream:
            self.assertEqual(
                change_stream.resume_token,
                expected_rt_getter(change_stream, listener))
            for _ in range(3):
                coll.insert_one({})
                change = next(change_stream)
                self.assertEqual(
                    change_stream.resume_token,
                    expected_rt_getter(change_stream, listener, change))

    # Prose test no. 1
    @client_context.require_version_min(4, 0, 7)
    def test_update_resume_token(self):
        self._test_update_resume_token(self._get_expected_resume_token)

    # Prose test no. 1
    @client_context.require_version_max(4, 0, 7)
    def test_update_resume_token_legacy(self):
        self._test_update_resume_token(self._get_expected_resume_token_legacy)

    # Prose test no. 2
    @client_context.require_version_max(4, 3, 3)  # PYTHON-2120
    @client_context.require_version_min(4, 1, 8)
    def test_raises_error_on_missing_id_418plus(self):
        # Server returns an error on 4.1.8+
        self._test_raises_error_on_missing_id(OperationFailure)

    # Prose test no. 2
    @client_context.require_version_max(4, 1, 8)
    def test_raises_error_on_missing_id_418minus(self):
        # PyMongo raises an error
        self._test_raises_error_on_missing_id(InvalidOperation)

    # Prose test no. 3
    def test_resume_on_error(self):
        with self.change_stream() as change_stream:
            self.insert_one_and_check(change_stream, {'_id': 1})
            # Cause a cursor not found error on the next getMore.
            self.kill_change_stream_cursor(change_stream)
            self.insert_one_and_check(change_stream, {'_id': 2})

    # Prose test no. 4
    @client_context.require_failCommand_fail_point
    def test_no_resume_attempt_if_aggregate_command_fails(self):
        # Set non-retryable error on aggregate command.
        fail_point = {'mode': {'times': 1},
                      'data': {'errorCode': 2, 'failCommands': ['aggregate']}}
        client, listener = self._client_with_listener("aggregate", "getMore")
        with self.fail_point(fail_point):
            try:
                _ = self.change_stream_with_client(client)
            except OperationFailure:
                pass

        # Driver should have attempted aggregate command only once.
        self.assertEqual(len(listener.results['started']), 1)
        self.assertEqual(listener.results['started'][0].command_name,
                         'aggregate')

    # Prose test no. 5 - REMOVED
    # Prose test no. 6 - SKIPPED
    # Reason: readPreference is not configurable using the watch() helpers
    #   so we can skip this test. Also, PyMongo performs server selection for
    #   each operation which ensure compliance with this prose test.

    # Prose test no. 7
    def test_initial_empty_batch(self):
        with self.change_stream() as change_stream:
            # The first batch should be empty.
            self.assertFalse(change_stream._cursor._has_next())
            cursor_id = change_stream._cursor.cursor_id
            self.assertTrue(cursor_id)
            self.insert_one_and_check(change_stream, {})
            # Make sure we're still using the same cursor.
            self.assertEqual(cursor_id, change_stream._cursor.cursor_id)

    # Prose test no. 8
    def test_kill_cursors(self):
        def raise_error():
            raise ServerSelectionTimeoutError('mock error')
        with self.change_stream() as change_stream:
            self.insert_one_and_check(change_stream, {'_id': 1})
            # Cause a cursor not found error on the next getMore.
            cursor = change_stream._cursor
            self.kill_change_stream_cursor(change_stream)
            cursor.close = raise_error
            self.insert_one_and_check(change_stream, {'_id': 2})

    # Prose test no. 9
    @client_context.require_version_min(4, 0, 0)
    @client_context.require_version_max(4, 0, 7)
    def test_start_at_operation_time_caching(self):
        # Case 1: change stream not started with startAtOperationTime
        client, listener = self.client_with_listener("aggregate")
        with self.change_stream_with_client(client) as cs:
            self.kill_change_stream_cursor(cs)
            cs.try_next()
        cmd = listener.results['started'][-1].command
        self.assertIsNotNone(cmd["pipeline"][0]["$changeStream"].get(
            "startAtOperationTime"))

        # Case 2: change stream started with startAtOperationTime
        listener.results.clear()
        optime = self.get_start_at_operation_time()
        with self.change_stream_with_client(
                client, start_at_operation_time=optime) as cs:
            self.kill_change_stream_cursor(cs)
            cs.try_next()
        cmd = listener.results['started'][-1].command
        self.assertEqual(cmd["pipeline"][0]["$changeStream"].get(
            "startAtOperationTime"), optime, str([k.command for k in
                                                  listener.results['started']]))

    # Prose test no. 10 - SKIPPED
    # This test is identical to prose test no. 3.

    # Prose test no. 11
    @client_context.require_version_min(4, 0, 7)
    def test_resumetoken_empty_batch(self):
        client, listener = self._client_with_listener("getMore")
        with self.change_stream_with_client(client) as change_stream:
            self.assertIsNone(change_stream.try_next())
            resume_token = change_stream.resume_token

        response = listener.results['succeeded'][0].reply
        self.assertEqual(resume_token,
                         response["cursor"]["postBatchResumeToken"])

    # Prose test no. 11
    @client_context.require_version_min(4, 0, 7)
    def test_resumetoken_exhausted_batch(self):
        client, listener = self._client_with_listener("getMore")
        with self.change_stream_with_client(client) as change_stream:
            self._populate_and_exhaust_change_stream(change_stream)
            resume_token = change_stream.resume_token

        response = listener.results['succeeded'][-1].reply
        self.assertEqual(resume_token,
                         response["cursor"]["postBatchResumeToken"])

    # Prose test no. 12
    @client_context.require_version_max(4, 0, 7)
    def test_resumetoken_empty_batch_legacy(self):
        resume_point = self.get_resume_token()

        # Empty resume token when neither resumeAfter or startAfter specified.
        with self.change_stream() as change_stream:
            change_stream.try_next()
            self.assertIsNone(change_stream.resume_token)

        # Resume token value is same as resumeAfter.
        with self.change_stream(resume_after=resume_point) as change_stream:
            change_stream.try_next()
            resume_token = change_stream.resume_token
            self.assertEqual(resume_token, resume_point)

    # Prose test no. 12
    @client_context.require_version_max(4, 0, 7)
    def test_resumetoken_exhausted_batch_legacy(self):
        # Resume token is _id of last change.
        with self.change_stream() as change_stream:
            change = self._populate_and_exhaust_change_stream(change_stream)
            self.assertEqual(change_stream.resume_token, change["_id"])
            resume_point = change['_id']

        # Resume token is _id of last change even if resumeAfter is specified.
        with self.change_stream(resume_after=resume_point) as change_stream:
            change = self._populate_and_exhaust_change_stream(change_stream)
            self.assertEqual(change_stream.resume_token, change["_id"])

    # Prose test no. 13
    def test_resumetoken_partially_iterated_batch(self):
        # When batch has been iterated up to but not including the last element.
        # Resume token should be _id of previous change document.
        with self.change_stream() as change_stream:
            self.watched_collection(
                write_concern=WriteConcern('majority')).insert_many(
                [{"data": k} for k in range(3)])
            for _ in range(2):
                change = next(change_stream)
            resume_token = change_stream.resume_token

        self.assertEqual(resume_token, change["_id"])

    def _test_resumetoken_uniterated_nonempty_batch(self, resume_option):
        # When the batch is not empty and hasn't been iterated at all.
        # Resume token should be same as the resume option used.
        resume_point = self.get_resume_token()

        # Insert some documents so that firstBatch isn't empty.
        self.watched_collection(
            write_concern=WriteConcern("majority")).insert_many(
            [{'a': 1}, {'b': 2}, {'c': 3}])

        # Resume token should be same as the resume option.
        with self.change_stream(
                **{resume_option: resume_point}) as change_stream:
            self.assertTrue(change_stream._cursor._has_next())
            resume_token = change_stream.resume_token
        self.assertEqual(resume_token, resume_point)

    # Prose test no. 14
    @client_context.require_no_mongos
    def test_resumetoken_uniterated_nonempty_batch_resumeafter(self):
        self._test_resumetoken_uniterated_nonempty_batch("resume_after")

    # Prose test no. 14
    @client_context.require_no_mongos
    @client_context.require_version_min(4, 1, 1)
    def test_resumetoken_uniterated_nonempty_batch_startafter(self):
        self._test_resumetoken_uniterated_nonempty_batch("start_after")

    # Prose test no. 17
    @client_context.require_version_min(4, 1, 1)
    def test_startafter_resume_uses_startafter_after_empty_getMore(self):
        # Resume should use startAfter after no changes have been returned.
        resume_point = self.get_resume_token()

        client, listener = self._client_with_listener("aggregate")
        with self.change_stream_with_client(
                client, start_after=resume_point) as change_stream:
            self.assertFalse(change_stream._cursor._has_next())  # No changes
            change_stream.try_next()                        # No changes
            self.kill_change_stream_cursor(change_stream)
            change_stream.try_next()                        # Resume attempt

        response = listener.results['started'][-1]
        self.assertIsNone(
            response.command["pipeline"][0]["$changeStream"].get("resumeAfter"))
        self.assertIsNotNone(
            response.command["pipeline"][0]["$changeStream"].get("startAfter"))

    # Prose test no. 18
    @client_context.require_version_min(4, 1, 1)
    def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self):
        # Resume should use resumeAfter after some changes have been returned.
        resume_point = self.get_resume_token()

        client, listener = self._client_with_listener("aggregate")
        with self.change_stream_with_client(
                client, start_after=resume_point) as change_stream:
            self.assertFalse(change_stream._cursor._has_next())  # No changes
            self.watched_collection().insert_one({})
            next(change_stream)                             # Changes
            self.kill_change_stream_cursor(change_stream)
            change_stream.try_next()                        # Resume attempt

        response = listener.results['started'][-1]
        self.assertIsNotNone(
            response.command["pipeline"][0]["$changeStream"].get("resumeAfter"))
        self.assertIsNone(
            response.command["pipeline"][0]["$changeStream"].get("startAfter"))


class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin):
    @classmethod
    @client_context.require_version_min(4, 0, 0, -1)
    @client_context.require_no_mmap
    @client_context.require_no_standalone
    def setUpClass(cls):
        super(TestClusterChangeStream, cls).setUpClass()
        cls.dbs = [cls.db, cls.client.pymongo_test_2]

    @classmethod
    def tearDownClass(cls):
        for db in cls.dbs:
            cls.client.drop_database(db)
        super(TestClusterChangeStream, cls).tearDownClass()

    def change_stream_with_client(self, client, *args, **kwargs):
        return client.watch(*args, **kwargs)

    def generate_invalidate_event(self, change_stream):
        self.skipTest("cluster-level change streams cannot be invalidated")

    def _test_get_invalidate_event(self, change_stream):
        # Cluster-level change streams don't get invalidated.
        pass

    def _test_invalidate_stops_iteration(self, change_stream):
        # Cluster-level change streams don't get invalidated.
        pass

    def _insert_and_check(self, change_stream, db, collname, doc):
        coll = db[collname]
        coll.insert_one(doc)
        change = next(change_stream)
        self.assertEqual(change['operationType'], 'insert')
        self.assertEqual(change['ns'], {'db': db.name,
                                        'coll': collname})
        self.assertEqual(change['fullDocument'], doc)

    def insert_one_and_check(self, change_stream, doc):
        db = random.choice(self.dbs)
        collname = self.id()
        self._insert_and_check(change_stream, db, collname, doc)

    def test_simple(self):
        collnames = self.generate_unique_collnames(3)
        with self.change_stream() as change_stream:
            for db, collname in product(self.dbs, collnames):
                self._insert_and_check(
                    change_stream, db, collname, {'_id': collname}
                )

    def test_aggregate_cursor_blocks(self):
        """Test that an aggregate cursor blocks until a change is readable."""
        with self.client.admin.aggregate(
                [{'$changeStream': {'allChangesForCluster': True}}],
                maxAwaitTimeMS=250) as change_stream:
            self._test_next_blocks(change_stream)

    def test_full_pipeline(self):
        """$changeStream must be the first stage in a change stream pipeline
        sent to the server.
        """
        self._test_full_pipeline({'allChangesForCluster': True})


class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin):
    @classmethod
    @client_context.require_version_min(4, 0, 0, -1)
    @client_context.require_no_mmap
    @client_context.require_no_standalone
    def setUpClass(cls):
        super(TestDatabaseChangeStream, cls).setUpClass()

    def change_stream_with_client(self, client, *args, **kwargs):
        return client[self.db.name].watch(*args, **kwargs)

    def generate_invalidate_event(self, change_stream):
        # Dropping the database invalidates the change stream.
        change_stream._client.drop_database(self.db.name)

    def _test_get_invalidate_event(self, change_stream):
        # Cache collection names.
        dropped_colls = self.db.list_collection_names()
        # Drop the watched database to get an invalidate event.
        self.generate_invalidate_event(change_stream)
        change = change_stream.next()
        # 4.1+ returns "drop" events for each collection in dropped database
        # and a "dropDatabase" event for the database itself.
        if change['operationType'] == 'drop':
            self.assertTrue(change['_id'])
            for _ in range(len(dropped_colls)):
                ns = change['ns']
                self.assertEqual(ns['db'], change_stream._target.name)
                self.assertIn(ns['coll'], dropped_colls)
                change = change_stream.next()
            self.assertEqual(change['operationType'], 'dropDatabase')
            self.assertTrue(change['_id'])
            self.assertEqual(change['ns'], {'db': change_stream._target.name})
            # Get next change.
            change = change_stream.next()
        self.assertTrue(change['_id'])
        self.assertEqual(change['operationType'], 'invalidate')
        self.assertNotIn('ns', change)
        self.assertNotIn('fullDocument', change)
        # The ChangeStream should be dead.
        with self.assertRaises(StopIteration):
            change_stream.next()

    def _test_invalidate_stops_iteration(self, change_stream):
        # Drop the watched database to get an invalidate event.
        change_stream._client.drop_database(self.db.name)
        # Check drop and dropDatabase events.
        for change in change_stream:
            self.assertIn(change['operationType'], (
                'drop', 'dropDatabase', 'invalidate'))
        # Last change must be invalidate.
        self.assertEqual(change['operationType'], 'invalidate')
        # Change stream must not allow further iteration.
        with self.assertRaises(StopIteration):
            change_stream.next()
        with self.assertRaises(StopIteration):
            next(change_stream)

    def _insert_and_check(self, change_stream, collname, doc):
        coll = self.db[collname]
        coll.insert_one(doc)
        change = next(change_stream)
        self.assertEqual(change['operationType'], 'insert')
        self.assertEqual(change['ns'], {'db': self.db.name,
                                        'coll': collname})
        self.assertEqual(change['fullDocument'], doc)

    def insert_one_and_check(self, change_stream, doc):
        self._insert_and_check(change_stream, self.id(), doc)

    def test_simple(self):
        collnames = self.generate_unique_collnames(3)
        with self.change_stream() as change_stream:
            for collname in collnames:
                self._insert_and_check(
                    change_stream, collname, {'_id': uuid.uuid4()})

    def test_isolation(self):
        # Ensure inserts to other dbs don't show up in our ChangeStream.
        other_db = self.client.pymongo_test_temp
        self.assertNotEqual(
            other_db, self.db, msg="Isolation must be tested on separate DBs")
        collname = self.id()
        with self.change_stream() as change_stream:
            other_db[collname].insert_one({'_id': uuid.uuid4()})
            self._insert_and_check(
                change_stream, collname, {'_id': uuid.uuid4()})
        self.client.drop_database(other_db)


class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin,
                                 ProseSpecTestsMixin):
    @classmethod
    @client_context.require_version_min(3, 5, 11)
    @client_context.require_no_mmap
    @client_context.require_no_standalone
    def setUpClass(cls):
        super(TestCollectionChangeStream, cls).setUpClass()

    def setUp(self):
        # Use a new collection for each test.
        self.watched_collection().drop()
        self.watched_collection().insert_one({})

    def change_stream_with_client(self, client, *args, **kwargs):
        return client[self.db.name].get_collection(
            self.watched_collection().name).watch(*args, **kwargs)

    def generate_invalidate_event(self, change_stream):
        # Dropping the collection invalidates the change stream.
        change_stream._target.drop()

    def _test_invalidate_stops_iteration(self, change_stream):
        self.generate_invalidate_event(change_stream)
        # Check drop and dropDatabase events.
        for change in change_stream:
            self.assertIn(change['operationType'], ('drop', 'invalidate'))
        # Last change must be invalidate.
        self.assertEqual(change['operationType'], 'invalidate')
        # Change stream must not allow further iteration.
        with self.assertRaises(StopIteration):
            change_stream.next()
        with self.assertRaises(StopIteration):
            next(change_stream)

    def _test_get_invalidate_event(self, change_stream):
        # Drop the watched database to get an invalidate event.
        change_stream._target.drop()
        change = change_stream.next()
        # 4.1+ returns a "drop" change document.
        if change['operationType'] == 'drop':
            self.assertTrue(change['_id'])
            self.assertEqual(change['ns'], {
                'db': change_stream._target.database.name,
                'coll': change_stream._target.name})
            # Last change should be invalidate.
            change = change_stream.next()
        self.assertTrue(change['_id'])
        self.assertEqual(change['operationType'], 'invalidate')
        self.assertNotIn('ns', change)
        self.assertNotIn('fullDocument', change)
        # The ChangeStream should be dead.
        with self.assertRaises(StopIteration):
            change_stream.next()

    def insert_one_and_check(self, change_stream, doc):
        self.watched_collection().insert_one(doc)
        change = next(change_stream)
        self.assertEqual(change['operationType'], 'insert')
        self.assertEqual(
            change['ns'], {'db': self.watched_collection().database.name,
                           'coll': self.watched_collection().name})
        self.assertEqual(change['fullDocument'], doc)

    def test_raw(self):
        """Test with RawBSONDocument."""
        raw_coll = self.watched_collection(
            codec_options=DEFAULT_RAW_BSON_OPTIONS)
        with raw_coll.watch() as change_stream:
            raw_doc = RawBSONDocument(encode({'_id': 1}))
            self.watched_collection().insert_one(raw_doc)
            change = next(change_stream)
            self.assertIsInstance(change, RawBSONDocument)
            self.assertEqual(change['operationType'], 'insert')
            self.assertEqual(
                change['ns']['db'], self.watched_collection().database.name)
            self.assertEqual(
                change['ns']['coll'], self.watched_collection().name)
            self.assertEqual(change['fullDocument'], raw_doc)

    def test_uuid_representations(self):
        """Test with uuid document _ids and different uuid_representation."""
        for uuid_representation in ALL_UUID_REPRESENTATIONS:
            for id_subtype in (STANDARD, PYTHON_LEGACY):
                options = self.watched_collection().codec_options.with_options(
                    uuid_representation=uuid_representation)
                coll = self.watched_collection(codec_options=options)
                with coll.watch() as change_stream:
                    coll.insert_one(
                        {'_id': Binary(uuid.uuid4().bytes, id_subtype)})
                    _ = change_stream.next()
                    resume_token = change_stream.resume_token

                # Should not error.
                coll.watch(resume_after=resume_token)

    def test_document_id_order(self):
        """Test with document _ids that need their order preserved."""
        random_keys = random.sample(string.ascii_letters,
                                    len(string.ascii_letters))
        random_doc = {'_id': SON([(key, key) for key in random_keys])}
        for document_class in (dict, SON, RawBSONDocument):
            options = self.watched_collection().codec_options.with_options(
                document_class=document_class)
            coll = self.watched_collection(codec_options=options)
            with coll.watch() as change_stream:
                coll.insert_one(random_doc)
                _ = change_stream.next()
                resume_token = change_stream.resume_token

            # The resume token is always a document.
            self.assertIsInstance(resume_token, document_class)
            # Should not error.
            coll.watch(resume_after=resume_token)
            coll.delete_many({})

    def test_read_concern(self):
        """Test readConcern is not validated by the driver."""
        # Read concern 'local' is not allowed for $changeStream.
        coll = self.watched_collection(read_concern=ReadConcern('local'))
        with self.assertRaises(OperationFailure):
            coll.watch()

        # Does not error.
        coll = self.watched_collection(read_concern=ReadConcern('majority'))
        with coll.watch():
            pass


class TestAllLegacyScenarios(IntegrationTest):
    RUN_ON_LOAD_BALANCER = True

    @classmethod
    @client_context.require_connection
    def setUpClass(cls):
        super(TestAllLegacyScenarios, cls).setUpClass()
        cls.listener = AllowListEventListener("aggregate", "getMore")
        cls.client = rs_or_single_client(event_listeners=[cls.listener])

    @classmethod
    def tearDownClass(cls):
        cls.client.close()
        super(TestAllLegacyScenarios, cls).tearDownClass()

    def setUp(self):
        super(TestAllLegacyScenarios, self).setUp()
        self.listener.results.clear()

    def setUpCluster(self, scenario_dict):
        assets = [(scenario_dict["database_name"],
                   scenario_dict["collection_name"]),
                  (scenario_dict.get("database2_name", "db2"),
                   scenario_dict.get("collection2_name", "coll2"))]
        for db, coll in assets:
            self.client.drop_database(db)
            self.client[db].create_collection(coll)

    def setFailPoint(self, scenario_dict):
        fail_point = scenario_dict.get("failPoint")
        if fail_point is None:
            return
        elif not client_context.test_commands_enabled:
            self.skipTest("Test commands must be enabled")

        fail_cmd = SON([('configureFailPoint', 'failCommand')])
        fail_cmd.update(fail_point)
        client_context.client.admin.command(fail_cmd)
        self.addCleanup(
            client_context.client.admin.command,
            'configureFailPoint', fail_cmd['configureFailPoint'], mode='off')

    def assert_list_contents_are_subset(self, superlist, sublist):
        """Check that each element in sublist is a subset of the corresponding
        element in superlist."""
        self.assertEqual(len(superlist), len(sublist))
        for sup, sub in zip(superlist, sublist):
            if isinstance(sub, dict):
                self.assert_dict_is_subset(sup, sub)
                continue
            if isinstance(sub, (list, tuple)):
                self.assert_list_contents_are_subset(sup, sub)
                continue
            self.assertEqual(sup, sub)

    def assert_dict_is_subset(self, superdict, subdict):
        """Check that subdict is a subset of superdict."""
        exempt_fields = ["documentKey", "_id", "getMore"]
        for key, value in iteritems(subdict):
            if key not in superdict:
                self.fail('Key %s not found in %s' % (key, superdict))
            if isinstance(value, dict):
                self.assert_dict_is_subset(superdict[key], value)
                continue
            if isinstance(value, (list, tuple)):
                self.assert_list_contents_are_subset(superdict[key], value)
                continue
            if key in exempt_fields:
                # Only check for presence of these exempt fields, but not value.
                self.assertIn(key, superdict)
            else:
                self.assertEqual(superdict[key], value)

    def check_event(self, event, expectation_dict):
        if event is None:
            self.fail()
        for key, value in iteritems(expectation_dict):
            if isinstance(value, dict):
                self.assert_dict_is_subset(getattr(event, key), value)
            else:
                self.assertEqual(getattr(event, key), value)

    def tearDown(self):
        self.listener.results.clear()


_TEST_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'change_streams')


def camel_to_snake(camel):
    # Regex to convert CamelCase to snake_case.
    snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()


def get_change_stream(client, scenario_def, test):
    # Get target namespace on which to instantiate change stream
    target = test["target"]
    if target == "collection":
        db = client.get_database(scenario_def["database_name"])
        cs_target = db.get_collection(scenario_def["collection_name"])
    elif target == "database":
        cs_target = client.get_database(scenario_def["database_name"])
    elif target == "client":
        cs_target = client
    else:
        raise ValueError("Invalid target in spec")

    # Construct change stream kwargs dict
    cs_pipeline = test["changeStreamPipeline"]
    options = test["changeStreamOptions"]
    cs_options = {}
    for key, value in iteritems(options):
        cs_options[camel_to_snake(key)] = value
    
    # Create and return change stream
    return cs_target.watch(pipeline=cs_pipeline, **cs_options)


def run_operation(client, operation):
    # Apply specified operations
    opname = camel_to_snake(operation["name"])
    arguments = operation.get("arguments", {})
    if opname == 'rename':
        # Special case for rename operation.
        arguments = {'new_name': arguments["to"]}
    cmd = getattr(client.get_database(
        operation["database"]).get_collection(
        operation["collection"]), opname
    )
    return cmd(**arguments)


def create_test(scenario_def, test):
    def run_scenario(self):
        # Set up
        self.setUpCluster(scenario_def)
        self.setFailPoint(test)
        is_error = test["result"].get("error", False)
        try:
            with get_change_stream(
                self.client, scenario_def, test
            ) as change_stream:
                for operation in test["operations"]:
                    # Run specified operations
                    run_operation(self.client, operation)
                num_expected_changes = len(test["result"].get("success", []))
                changes = [
                    change_stream.next() for _ in range(num_expected_changes)]
                # Run a next() to induce an error if one is expected and
                # there are no changes.
                if is_error and not changes:
                    change_stream.next()

        except OperationFailure as exc:
            if not is_error:
                raise
            expected_code = test["result"]["error"]["code"]
            self.assertEqual(exc.code, expected_code)

        else:
            # Check for expected output from change streams
            if test["result"].get("success"):
                for change, expected_changes in zip(changes, test["result"]["success"]):
                    self.assert_dict_is_subset(change, expected_changes)
                self.assertEqual(len(changes), len(test["result"]["success"]))
        
        finally:
            # Check for expected events
            results = self.listener.results
            # Note: expectations may be missing, null, or a list of events.
            # Extra events emitted by the test are intentionally ignored.
            for idx, expectation in enumerate(test.get("expectations") or []):
                for event_type, event_desc in iteritems(expectation):
                    results_key = event_type.split("_")[1]
                    event = results[results_key][idx] if len(results[results_key]) > idx else None
                    self.check_event(event, event_desc)

    return run_scenario


def create_tests():
    for dirpath, _, filenames in os.walk(os.path.join(_TEST_PATH, 'legacy')):
        dirname = os.path.split(dirpath)[-1]

        for filename in filenames:
            with open(os.path.join(dirpath, filename)) as scenario_stream:
                scenario_def = json_util.loads(scenario_stream.read())

            test_type = os.path.splitext(filename)[0]

            for test in scenario_def['tests']:
                new_test = create_test(scenario_def, test)
                new_test = client_context.require_no_mmap(new_test)

                if 'minServerVersion' in test:
                    min_ver = tuple(
                        int(elt) for
                        elt in test['minServerVersion'].split('.'))
                    new_test = client_context.require_version_min(*min_ver)(
                        new_test)
                if 'maxServerVersion' in test:
                    max_ver = tuple(
                        int(elt) for
                        elt in test['maxServerVersion'].split('.'))
                    new_test = client_context.require_version_max(*max_ver)(
                        new_test)

                topologies = test['topology']
                new_test = client_context.require_cluster_type(topologies)(
                    new_test)

                test_name = 'test_%s_%s_%s' % (
                    dirname,
                    test_type.replace("-", "_"),
                    str(test['description'].replace(" ", "_")))

                new_test.__name__ = test_name
                setattr(TestAllLegacyScenarios, new_test.__name__, new_test)


create_tests()


globals().update(generate_test_classes(
    os.path.join(_TEST_PATH, 'unified'),
    module=__name__,))


if __name__ == '__main__':
    unittest.main()