Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

hemamaps / celery   python

Repository URL to install this package:

Version: 3.1.23 

/ tests / worker / test_consumer.py

from __future__ import absolute_import

import errno
import socket

from billiard.exceptions import RestartFreqExceeded

from celery.datastructures import LimitedSet
from celery.worker import state as worker_state
from celery.worker.consumer import (
    Consumer,
    Heart,
    Tasks,
    Agent,
    Mingle,
    Gossip,
    dump_body,
    CLOSE,
)

from celery.tests.case import AppCase, ContextMock, Mock, SkipTest, call, patch


class test_Consumer(AppCase):
    """Tests for the worker ``Consumer`` built from mocked collaborators."""

    def get_consumer(self, no_hub=False, **kwargs):
        """Return a ``Consumer`` whose collaborators are all mocks.

        :param no_hub: pass ``hub=None`` instead of a mock hub.
        :param kwargs: extra keyword arguments forwarded to ``Consumer``.

        The blueprint and restart state are replaced with mocks, the
        connection is a fake AMQP connection from ``_amqp_connection``,
        and ``connection_errors`` is set to ``(socket.error, OSError)``
        (presumably so that the ``socket.error`` raised by the restart
        tests below counts as a connection error -- confirm).
        """
        consumer = Consumer(
            on_task_request=Mock(),
            init_callback=Mock(),
            pool=Mock(),
            app=self.app,
            timer=Mock(),
            controller=Mock(),
            hub=None if no_hub else Mock(),
            **kwargs
        )
        consumer.blueprint = Mock()
        consumer._restart_state = Mock()
        consumer.connection = _amqp_connection()
        consumer.connection_errors = (socket.error, OSError, )
        return consumer

    def test_taskbuckets_defaultdict(self):
        """An unknown task name maps to ``None`` in ``task_buckets``."""
        c = self.get_consumer()
        self.assertIsNone(c.task_buckets['fooxasdwx.wewe'])

    def test_dump_body_buffer(self):
        """``dump_body`` accepts a ``buffer`` body.

        Skipped where the ``buffer`` builtin does not exist (Python 3).
        """
        msg = Mock()
        msg.body = 'str'
        try:
            buf = buffer(msg.body)
        except NameError:
            raise SkipTest('buffer type not available')
        self.assertTrue(dump_body(msg, buf))

    def test_sets_heartbeat(self):
        """An explicit ``amqheartbeat`` wins; ``None`` uses BROKER_HEARTBEAT."""
        c = self.get_consumer(amqheartbeat=10)
        self.assertEqual(c.amqheartbeat, 10)
        self.app.conf.BROKER_HEARTBEAT = 20
        c = self.get_consumer(amqheartbeat=None)
        self.assertEqual(c.amqheartbeat, 20)

    def test_gevent_bug_disables_connection_timeout(self):
        """Under gevent, BROKER_CONNECTION_TIMEOUT is reset to ``None``."""
        with patch('celery.worker.consumer._detect_environment') as de:
            de.return_value = 'gevent'
            self.app.conf.BROKER_CONNECTION_TIMEOUT = 33.33
            self.get_consumer()
            self.assertIsNone(self.app.conf.BROKER_CONNECTION_TIMEOUT)

    def test_limit_task(self):
        """``_limit_task`` dispatches now or reschedules, depending on tokens."""
        c = self.get_consumer()

        # Bucket has capacity: the request is reserved and dispatched now.
        with patch('celery.worker.consumer.task_reserved') as reserved:
            bucket = Mock()
            request = Mock()
            bucket.can_consume.return_value = True

            c._limit_task(request, bucket, 3)
            bucket.can_consume.assert_called_with(3)
            reserved.assert_called_with(request)
            c.on_task_request.assert_called_with(request)

        # Bucket is empty (reusing ``bucket``/``request`` from above): the
        # request is rescheduled on the timer after ``expected_time`` and
        # is not reserved.
        with patch('celery.worker.consumer.task_reserved') as reserved:
            bucket.can_consume.return_value = False
            bucket.expected_time.return_value = 3.33
            c._limit_task(request, bucket, 4)
            bucket.can_consume.assert_called_with(4)
            c.timer.call_after.assert_called_with(
                3.33, c._limit_task, (request, bucket, 4),
            )
            bucket.expected_time.assert_called_with(4)
            self.assertFalse(reserved.called)

    def test_start_blueprint_raises_EMFILE(self):
        """An ``EMFILE`` OSError from the blueprint propagates from start()."""
        c = self.get_consumer()
        exc = c.blueprint.start.side_effect = OSError()
        exc.errno = errno.EMFILE

        with self.assertRaises(OSError):
            c.start()

    def test_max_restarts_exceeded(self):
        """``RestartFreqExceeded`` makes ``start()`` sleep for one second."""
        c = self.get_consumer()

        def se(*args, **kwargs):
            # Mark the blueprint closed so the start loop ends after this.
            c.blueprint.state = CLOSE
            raise RestartFreqExceeded()
        c._restart_state.step.side_effect = se
        c.blueprint.start.side_effect = socket.error()

        with patch('celery.worker.consumer.sleep') as sleep:
            c.start()
            sleep.assert_called_with(1)

    def _closer(self, c):
        """Return a side-effect callable that sets ``c.blueprint.state``
        to ``CLOSE``, used to stop the consumer's start loop."""
        def se(*args, **kwargs):
            c.blueprint.state = CLOSE
        return se

    def test_collects_at_restart(self):
        """A connection error triggers ``connection.collect()`` on restart.

        The ``MemoryError`` raised by ``collect`` does not escape ``start()``.
        """
        c = self.get_consumer()
        c.connection.collect.side_effect = MemoryError()
        c.blueprint.start.side_effect = socket.error()
        c.blueprint.restart.side_effect = self._closer(c)
        c.start()
        c.connection.collect.assert_called_with()

    def test_register_with_event_loop(self):
        """Smoke test: registering with a mock event loop does not raise."""
        c = self.get_consumer()
        c.register_with_event_loop(Mock(name='loop'))

    def test_on_close_clears_semaphore_timer_and_reqs(self):
        """``on_close`` clears state and tolerates missing collaborators.

        It clears the semaphore, timer and reserved requests and flushes
        the pool; it also copes with controller/timer/pool being ``None``.
        """
        with patch('celery.worker.consumer.reserved_requests') as reserved:
            c = self.get_consumer()
            c.on_close()
            c.controller.semaphore.clear.assert_called_with()
            c.timer.clear.assert_called_with()
            reserved.clear.assert_called_with()
            c.pool.flush.assert_called_with()

            c.controller = None
            c.timer = None
            c.pool = None
            c.on_close()

    def test_connect_error_handler(self):
        """``connect()`` uses ``ensure_connection``; its errback is callable."""
        self.app.connection = _amqp_connection()
        conn = self.app.connection.return_value
        c = self.get_consumer()
        self.assertTrue(c.connect())
        self.assertTrue(conn.ensure_connection.called)
        # The errback is the first positional argument to ensure_connection.
        errback = conn.ensure_connection.call_args[0][0]
        # NOTE(review): presumably the errback reads ``conn.alt`` when
        # reporting alternate hosts -- confirm against Consumer.connect.
        conn.alt = [(1, 2, 3)]
        errback(Mock(), 0)


class test_Heart(AppCase):
    """Tests for the ``Heart`` bootstep that starts the worker heartbeat."""

    def _consumer(self):
        """Return a mock consumer with mock ``timer``/``event_dispatcher``."""
        c = Mock()
        c.timer = Mock()
        c.event_dispatcher = Mock()
        return c

    def test_start(self):
        """Defaults: enabled, no interval; ``start`` creates and starts it."""
        c = self._consumer()

        with patch('celery.worker.heartbeat.Heart') as hcls:
            h = Heart(c)
            self.assertTrue(h.enabled)
            self.assertIsNone(h.heartbeat_interval)
            self.assertIsNone(c.heart)

            h.start(c)
            # ``c`` is a Mock, so any attribute is truthy; check identity
            # with the instance produced by the patched Heart class instead.
            self.assertIs(c.heart, hcls.return_value)
            hcls.assert_called_with(c.timer, c.event_dispatcher,
                                    h.heartbeat_interval)
            c.heart.start.assert_called_with()

    def test_start_heartbeat_interval(self):
        """An explicit interval is stored and passed to the heart on start."""
        c = self._consumer()

        with patch('celery.worker.heartbeat.Heart') as hcls:
            h = Heart(c, False, 20)
            self.assertTrue(h.enabled)
            self.assertEqual(h.heartbeat_interval, 20)
            self.assertIsNone(c.heart)

            h.start(c)
            self.assertIs(c.heart, hcls.return_value)
            hcls.assert_called_with(c.timer, c.event_dispatcher,
                                    h.heartbeat_interval)
            c.heart.start.assert_called_with()


class test_Tasks(AppCase):
    """Tests for stopping the ``Tasks`` bootstep."""

    def test_stop(self):
        """Init resets consumer state; ``stop`` accepts a live task consumer."""
        consumer = Mock()
        step = Tasks(consumer)
        # Building the step resets both attributes on the consumer.
        for attr in ('task_consumer', 'qos'):
            self.assertIsNone(getattr(consumer, attr))

        consumer.task_consumer = Mock()
        step.stop(consumer)

    def test_stop_already_stopped(self):
        """``stop`` does not raise when no task consumer was ever started."""
        consumer = Mock()
        Tasks(consumer).stop(consumer)


class test_Agent(AppCase):
    """Tests for the ``Agent`` bootstep."""

    def test_start(self):
        """``create`` instantiates ``agent_cls`` with the connection."""
        consumer = Mock()
        step = Agent(consumer)
        step.instantiate = Mock()
        step.agent_cls = 'foo:Agent'

        created = step.create(consumer)
        self.assertIsNotNone(created)
        step.instantiate.assert_called_with(
            step.agent_cls, consumer.connection,
        )


class test_Mingle(AppCase):
    """Tests for the ``Mingle`` bootstep that syncs with neighbour nodes."""

    def test_start_no_replies(self):
        """With no replies, neighbours are queried but nothing is merged."""
        c = Mock()
        c.app.connection = _amqp_connection()
        mingle = Mingle(c)
        inspector = c.app.control.inspect.return_value = Mock()
        inspector.hello.return_value = {}
        mingle.start(c)
        inspector.hello.assert_called_with(
            c.hostname, worker_state.revoked._data,
        )
        # No replies means there is no remote clock to merge.
        self.assertFalse(c.app.clock.adjust.called)

    def test_start(self):
        """Replies adjust the clock and merge revoked ids; errors are skipped.

        ``worker_state.revoked`` is global, so it is cleared afterwards.
        """
        try:
            c = Mock()
            c.app.connection = _amqp_connection()
            mingle = Mingle(c)
            self.assertTrue(mingle.enabled)

            Aig = LimitedSet()
            Big = LimitedSet()
            Aig.add('Aig-1')
            Aig.add('Aig-2')
            Big.add('Big-1')

            inspector = c.app.control.inspect.return_value = Mock()
            # The reply from C@example.com carries an error instead of
            # clock/revoked data.
            inspector.hello.return_value = {
                'A@example.com': {
                    'clock': 312,
                    'revoked': Aig._data,
                },
                'B@example.com': {
                    'clock': 29,
                    'revoked': Big._data,
                },
                'C@example.com': {
                    'error': 'unknown method',
                },
            }

            mingle.start(c)
            inspector.hello.assert_called_with(
                c.hostname, worker_state.revoked._data,
            )
            c.app.clock.adjust.assert_has_calls([
                call(312), call(29),
            ], any_order=True)
            self.assertIn('Aig-1', worker_state.revoked)
            self.assertIn('Aig-2', worker_state.revoked)
            self.assertIn('Big-1', worker_state.revoked)
        finally:
            worker_state.revoked.clear()


def _amqp_connection():
    """Return a mock connection factory whose connections report AMQP.

    Calling the returned ``ContextMock`` yields another ``ContextMock``
    whose ``transport.driver_type`` is ``'amqp'``.
    """
    conn = ContextMock()
    conn.transport.driver_type = 'amqp'
    factory = ContextMock()
    factory.return_value = conn
    return factory


class test_Gossip(AppCase):
    """Tests for the ``Gossip`` bootstep: node events and elections.

    ``self.Consumer()`` already attaches a fake AMQP connection, so the
    tests do not set ``c.app.connection`` again.
    """

    def test_init(self):
        """Gossip is enabled and registers itself as ``c.gossip``."""
        c = self.Consumer()
        g = Gossip(c)
        self.assertTrue(g.enabled)
        self.assertIs(c.gossip, g)

    def test_callbacks(self):
        """``on_node_*`` call every handler registered on ``g.on``."""
        c = self.Consumer()
        g = Gossip(c)
        on_node_join = Mock(name='on_node_join')
        on_node_join2 = Mock(name='on_node_join2')
        on_node_leave = Mock(name='on_node_leave')
        on_node_lost = Mock(name='on_node_lost')
        g.on.node_join.add(on_node_join)
        g.on.node_join.add(on_node_join2)
        g.on.node_leave.add(on_node_leave)
        g.on.node_lost.add(on_node_lost)

        worker = Mock(name='worker')
        g.on_node_join(worker)
        on_node_join.assert_called_with(worker)
        on_node_join2.assert_called_with(worker)
        g.on_node_leave(worker)
        on_node_leave.assert_called_with(worker)
        g.on_node_lost(worker)
        on_node_lost.assert_called_with(worker)

    def test_election(self):
        """``election`` records no replies yet and sends ``worker-elect``."""
        c = self.Consumer()
        g = Gossip(c)
        g.start(c)
        g.election('id', 'topic', 'action')
        self.assertListEqual(g.consensus_replies['id'], [])
        g.dispatcher.send.assert_called_with(
            'worker-elect', id='id', topic='topic', cver=1, action='action',
        )

    def test_call_task(self):
        """``call_task`` applies the signature and logs errors from it."""
        c = self.Consumer()
        g = Gossip(c)
        g.start(c)

        with patch('celery.worker.consumer.signature') as signature:
            sig = signature.return_value = Mock()
            task = Mock()
            g.call_task(task)
            signature.assert_called_with(task, app=c.app)
            sig.apply_async.assert_called_with()

            sig.apply_async.side_effect = MemoryError()
            with patch('celery.worker.consumer.error') as error:
                g.call_task(task)
                self.assertTrue(error.called)

    def Event(self, id='id', clock=312,
              hostname='foo@example.com', pid=4312,
              topic='topic', action='action', cver=1):
        """Return an election event dict built from the given fields."""
        return {
            'id': id,
            'clock': clock,
            'hostname': hostname,
            'pid': pid,
            'topic': topic,
            'action': action,
            'cver': cver,
        }

    def test_on_elect(self):
        """``on_elect`` queues and acks a request; bad events are logged."""
        c = self.Consumer()
        g = Gossip(c)
        g.start(c)

        event = self.Event('id1')
        g.on_elect(event)
        in_heap = g.consensus_requests['id1']
        self.assertTrue(in_heap)
        g.dispatcher.send.assert_called_with('worker-elect-ack', id='id1')

        # An event without a clock is reported through ``error``.
        event.pop('clock')
        with patch('celery.worker.consumer.error') as error:
            g.on_elect(event)
            self.assertTrue(error.called)

    def Consumer(self, hostname='foo@x.com', pid=4312):
        """Return a mock consumer with a fake AMQP connection.

        :param hostname: value for ``c.hostname``.
        :param pid: value for ``c.pid``.
        """
        c = Mock()
        c.app.connection = _amqp_connection()
        c.hostname = hostname
        c.pid = pid
        return c

    def setup_election(self, g, c):
        """Start *g*, run a three-node election for ``'id1'`` and ack it.

        After the third ack the reply list for ``'id1'`` must be gone.
        """
        g.start(c)
        g.clock = self.app.clock
        # An ack for an unknown election id is ignored.
        self.assertNotIn('idx', g.consensus_replies)
        self.assertIsNone(g.on_elect_ack({'id': 'idx'}))

        g.state.alive_workers.return_value = [
            'foo@x.com', 'bar@x.com', 'baz@x.com',
        ]
        g.consensus_replies['id1'] = []
        g.consensus_requests['id1'] = []
        e1 = self.Event('id1', 1, 'foo@x.com')
        e2 = self.Event('id1', 2, 'bar@x.com')
        e3 = self.Event('id1', 3, 'baz@x.com')
        g.on_elect(e1)
        g.on_elect(e2)
        g.on_elect(e3)
        self.assertEqual(len(g.consensus_requests['id1']), 3)

        with patch('celery.worker.consumer.info'):
            g.on_elect_ack(e1)
            self.assertEqual(len(g.consensus_replies['id1']), 1)
            g.on_elect_ack(e2)
            self.assertEqual(len(g.consensus_replies['id1']), 2)
            g.on_elect_ack(e3)
            with self.assertRaises(KeyError):
                g.consensus_replies['id1']

    def test_on_elect_ack_win(self):
        """The winning node runs the topic handler with the action."""
        c = self.Consumer(hostname='foo@x.com')  # I will win
        g = Gossip(c)
        handler = g.election_handlers['topic'] = Mock()
        self.setup_election(g, c)
        handler.assert_called_with('action')

    def test_on_elect_ack_lose(self):
        """A losing node does not run the topic handler."""
        c = self.Consumer(hostname='bar@x.com')  # I will lose
        g = Gossip(c)
        handler = g.election_handlers['topic'] = Mock()
        self.setup_election(g, c)
        self.assertFalse(handler.called)

    def test_on_elect_ack_win_but_no_action(self):
        """Winning with no handler registered for the topic logs an error."""
        c = self.Consumer(hostname='foo@x.com')  # I will win
        g = Gossip(c)
        g.election_handlers = {}
        with patch('celery.worker.consumer.error') as error:
            self.setup_election(g, c)
            self.assertTrue(error.called)

    def test_on_node_join(self):
        """A node join is logged at debug level with its hostname."""
        c = self.Consumer()
        g = Gossip(c)
        with patch('celery.worker.consumer.debug') as debug:
            g.on_node_join(c)
            debug.assert_called_with('%s joined the party', 'foo@x.com')

    def test_on_node_leave(self):
        """A node leave is logged at debug level with its hostname."""
        c = self.Consumer()
        g = Gossip(c)
        with patch('celery.worker.consumer.debug') as debug:
            g.on_node_leave(c)
            debug.assert_called_with('%s left', 'foo@x.com')

    def test_on_node_lost(self):
        """A lost node is logged at info level as a missed heartbeat."""
        c = self.Consumer()
        g = Gossip(c)
        with patch('celery.worker.consumer.info') as info:
            g.on_node_lost(c)
            info.assert_called_with('missed heartbeat from %s', 'foo@x.com')

    def test_register_timer(self):
        """``register_timer`` schedules ``periodic`` and cancels the old one."""
        c = self.Consumer()
        g = Gossip(c)
        g.register_timer()
        c.timer.call_repeatedly.assert_called_with(g.interval, g.periodic)
        tref = g._tref
        g.register_timer()
        tref.cancel.assert_called_with()

    def test_periodic(self):
        """``periodic`` reports dead workers as lost and forgets them."""
        c = self.Consumer()
        g = Gossip(c)
        g.on_node_lost = Mock()
        state = g.state = Mock()
        worker = Mock()
        state.workers = {'foo': worker}
        worker.alive = True
        worker.hostname = 'foo'
        g.periodic()

        worker.alive = False
        g.periodic()
        g.on_node_lost.assert_called_with(worker)
        with self.assertRaises(KeyError):
            state.workers['foo']

    def test_on_message(self):
        """``on_message`` routes events to handlers and updates state."""
        c = self.Consumer()
        g = Gossip(c)
        self.assertTrue(g.enabled)
        prepare = Mock()
        prepare.return_value = 'worker-online', {}
        c.app.events.State.assert_called_with(
            on_node_join=g.on_node_join,
            on_node_leave=g.on_node_leave,
            max_tasks_in_memory=1,
        )
        g.update_state = Mock()
        worker = Mock()
        g.on_node_join = Mock()
        g.on_node_leave = Mock()
        g.update_state.return_value = worker, 1
        message = Mock()
        message.delivery_info = {'routing_key': 'worker-online'}
        message.headers = {'hostname': 'other'}

        # A registered handler receives the raw payload.
        handler = g.event_handlers['worker-online'] = Mock()
        g.on_message(prepare, message)
        handler.assert_called_with(message.payload)
        g.event_handlers = {}

        g.on_message(prepare, message)

        message.delivery_info = {'routing_key': 'worker-offline'}
        prepare.return_value = 'worker-offline', {}
        g.on_message(prepare, message)

        message.delivery_info = {'routing_key': 'worker-baz'}
        prepare.return_value = 'worker-baz', {}
        g.update_state.return_value = worker, 0
        g.on_message(prepare, message)

        # A message from this node's own hostname ends with clock.forward().
        message.headers = {'hostname': g.hostname}
        g.on_message(prepare, message)
        g.clock.forward.assert_called_with()