Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

hemamaps / celery   python

Repository URL to install this package:

Version: 3.1.23 

/ tests / worker / test_worker.py

from __future__ import absolute_import, print_function

import os
import socket

from collections import deque
from datetime import datetime, timedelta
from threading import Event

from amqp import ChannelError
from kombu import Connection
from kombu.common import QoS, ignore_errors
from kombu.transport.base import Message

from celery.app.defaults import DEFAULTS
from celery.bootsteps import RUN, CLOSE, StartStopStep
from celery.concurrency.base import BasePool
from celery.datastructures import AttributeDict
from celery.exceptions import (
    WorkerShutdown, WorkerTerminate, TaskRevokedError,
)
from celery.five import Empty, range, Queue as FastQueue
from celery.utils import uuid
from celery.worker import components
from celery.worker import consumer
from celery.worker.consumer import Consumer as __Consumer
from celery.worker.job import Request
from celery.utils import worker_direct
from celery.utils.serialization import pickle
from celery.utils.timer2 import Timer

from celery.tests.case import AppCase, Mock, SkipTest, patch, restore_logging


def MockStep(step=None):
    """Return *step* (or a fresh ``Mock``) dressed up as a bootstep.

    The object receives a mock ``blueprint`` whose ``name`` is ``'MockNS'``
    and a ``name`` of its own that embeds ``id(step)``, so distinct mock
    steps get distinct names.
    """
    if step is None:
        step = Mock()
    blueprint = Mock()
    blueprint.name = 'MockNS'
    step.blueprint = blueprint
    step.name = 'MockStep(%s)' % (id(step), )
    return step


def mock_event_dispatcher():
    """Build a ``Mock`` standing in for the consumer's event dispatcher.

    The mock belongs to the ``'worker'`` event group and starts with a
    fresh, empty ``deque`` as its outbound buffer.
    """
    dispatcher = Mock(name='event_dispatcher')
    dispatcher._outbound_buffer = deque()
    dispatcher.groups = ['worker']
    return dispatcher


class PlaceHolder(object):
    """Empty marker class with no attributes or behaviour of its own."""


def find_step(obj, typ):
    """Return the step registered on ``obj.blueprint`` for step class *typ*.

    The lookup key is ``typ.name`` in ``obj.blueprint.steps``.
    """
    steps = obj.blueprint.steps
    key = typ.name
    return steps[key]


class Consumer(__Consumer):
    """Worker consumer with Mingle, Gossip and Heart disabled by default.

    Any of the three can still be re-enabled by passing the matching
    ``without_*`` keyword explicitly.
    """

    def __init__(self, *args, **kwargs):
        for flag in ('without_mingle',      # Mingle step
                     'without_gossip',      # Gossip step
                     'without_heartbeat'):  # Heart step
            kwargs.setdefault(flag, True)
        super(Consumer, self).__init__(*args, **kwargs)


class _MyKombuConsumer(Consumer):
    """Test consumer whose ``pool`` defaults to ``BasePool(2)``.

    ``broadcast_consumer`` and ``task_consumer`` are class-level mocks, so
    they are shared by all instances unless overridden per instance.
    """
    broadcast_consumer = Mock()
    task_consumer = Mock()

    def __init__(self, *args, **kwargs):
        default_pool = BasePool(2)
        kwargs.setdefault('pool', default_pool)
        super(_MyKombuConsumer, self).__init__(*args, **kwargs)

    def restart_heartbeat(self):
        # Rather than restarting anything, simply drop the heart reference.
        self.heart = None


class MyKombuConsumer(Consumer):
    """Consumer whose ``loop`` does nothing, so tests drive it by hand."""

    def loop(self, *args, **kwargs):
        # Deliberately a no-op; every argument is ignored.
        return None


class MockNode(object):
    """Minimal stand-in for a pidbox node that records received commands.

    ``commands`` is created per instance. A class-level list would be
    shared by every ``MockNode`` and would carry commands over from one
    test to the next.
    """

    def __init__(self):
        self.commands = []

    def handle_message(self, body, message):
        """Record ``body['command']``, or ``None`` when the key is absent.

        *message* is ignored. The ``'command'`` key is popped, so *body* is
        mutated.
        """
        self.commands.append(body.pop('command', None))


class MockEventDispatcher(object):
    """Event-dispatcher stand-in recording sent events, close and flush.

    The mutable ``sent`` and ``_outbound_buffer`` lists are created per
    instance. As class attributes they were shared by every instance and
    leaked state between tests.
    """
    closed = False
    flushed = False

    def __init__(self):
        self.sent = []
        self._outbound_buffer = []

    def send(self, event, *args, **kwargs):
        """Record *event*; any additional arguments are ignored."""
        self.sent.append(event)

    def close(self):
        """Mark this dispatcher as closed."""
        self.closed = True

    def flush(self):
        """Mark this dispatcher as flushed."""
        self.flushed = True


class MockHeart(object):
    """Heart stand-in; ``stop()`` only records that it was called."""
    closed = False

    def stop(self):
        # Sets the flag on this instance, shadowing the class default.
        setattr(self, 'closed', True)


def create_message(channel, **data):
    """Build a kombu ``Message`` whose body is *data* pickled.

    ``data['id']`` defaults to a fresh ``uuid()``. As a side effect,
    ``channel.no_ack_consumers`` is reset to an empty set. The message
    accepts only the ``application/x-python-serialize`` content type.
    """
    data.setdefault('id', uuid())
    channel.no_ack_consumers = set()
    payload = pickle.dumps(dict(**data))
    message = Message(
        channel,
        body=payload,
        content_type='application/x-python-serialize',
        content_encoding='binary',
        delivery_info={'consumer_tag': 'mock'},
    )
    message.accept = ['application/x-python-serialize']
    return message


class test_Consumer(AppCase):
    """Tests for the worker ``Consumer`` and its bootsteps.

    Each test builds its own consumer around ``self.buffer.put`` (where
    ready requests end up) and ``self.timer`` (where ETA tasks are held).
    ``teardown`` stops the timer again.
    """

    def setup(self):
        self.buffer = FastQueue()
        self.timer = Timer()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z
        self.foo_task = foo_task

    def teardown(self):
        self.timer.stop()

    def test_info(self):
        """``WorkController.stats()`` reports prefetch count and broker."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        l.connection = Mock()
        l.connection.info.return_value = {'foo': 'bar'}
        l.controller = l.app.WorkController()
        l.controller.pool = Mock()
        l.controller.pool.info.return_value = [Mock(), Mock()]
        l.controller.consumer = l
        info = l.controller.stats()
        self.assertEqual(info['prefetch_count'], 10)
        self.assertTrue(info['broker'])

    def test_start_when_closed(self):
        """``start()`` on an already CLOSE blueprint returns cleanly."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = CLOSE
        l.start()

    def test_connection(self):
        """Blueprint start/restart/shutdown manage the broker connection."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)

        l.blueprint.start(l)
        self.assertIsInstance(l.connection, Connection)

        l.blueprint.state = RUN
        l.event_dispatcher = None
        l.blueprint.restart(l)
        self.assertTrue(l.connection)

        l.blueprint.state = RUN
        l.shutdown()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

        l.blueprint.start(l)
        self.assertIsInstance(l.connection, Connection)
        l.blueprint.restart(l)

        l.stop()
        l.shutdown()
        self.assertIsNone(l.connection)
        self.assertIsNone(l.task_consumer)

    def test_close_connection(self):
        """Connection, Events and Heart steps release resources on shutdown."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        step = find_step(l, consumer.Connection)
        conn = l.connection = Mock()
        step.shutdown(l)
        self.assertTrue(conn.close.called)
        self.assertIsNone(l.connection)

        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        eventer = l.event_dispatcher = mock_event_dispatcher()
        eventer.enabled = True
        heart = l.heart = MockHeart()
        l.blueprint.state = RUN
        Events = find_step(l, consumer.Events)
        Events.shutdown(l)
        Heart = find_step(l, consumer.Heart)
        Heart.shutdown(l)
        self.assertTrue(eventer.close.call_count)
        self.assertTrue(heart.closed)

    @patch('celery.worker.consumer.warn')
    def test_receive_message_unknown(self, warn):
        """A message that is not a task message produces a warning."""
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        m = create_message(backend, unknown={'baz': '!!!'})
        l.event_dispatcher = mock_event_dispatcher()
        l.node = MockNode()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertTrue(warn.call_count)

    @patch('celery.worker.strategy.to_timestamp')
    def test_receive_message_eta_OverflowError(self, to_timestamp):
        """The message is still acked when ETA conversion overflows."""
        to_timestamp.side_effect = OverflowError()
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        m = create_message(Mock(), task=self.foo_task.name,
                           args=('2, 2'),
                           kwargs={},
                           eta=datetime.now().isoformat())
        l.event_dispatcher = mock_event_dispatcher()
        l.node = MockNode()
        l.update_strategies()
        l.qos = Mock()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertTrue(m.acknowledged)

    @patch('celery.worker.consumer.error')
    def test_receive_message_InvalidTaskError(self, error):
        """Invalid task kwargs are logged as an invalid task message."""
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.event_dispatcher = mock_event_dispatcher()
        l.steps.pop()
        m = create_message(Mock(), task=self.foo_task.name,
                           args=(1, 2), kwargs='foobarbaz', id=1)
        l.update_strategies()

        callback = self._get_on_message(l)
        callback(m.decode(), m)
        self.assertIn('Received invalid task message', error.call_args[0][0])

    @patch('celery.worker.consumer.crit')
    def test_on_decode_error(self, crit):
        """Undecodable messages are acked and logged as critical."""
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)

        class MockMessage(Mock):
            content_type = 'application/x-msgpack'
            content_encoding = 'binary'
            body = 'foobarbaz'

        message = MockMessage()
        l.on_decode_error(message, KeyError('foo'))
        self.assertTrue(message.ack.call_count)
        self.assertIn("Can't decode message body", crit.call_args[0][0])

    def _get_on_message(self, l):
        """Return the message callback that ``l.loop`` registers.

        The loop is run once with mocked collaborators; ``drain_events``
        raises ``WorkerShutdown`` to break out of it. The callback then
        registered on ``l.task_consumer`` is returned.
        """
        if l.qos is None:
            l.qos = Mock()
        l.event_dispatcher = mock_event_dispatcher()
        l.task_consumer = Mock()
        l.connection = Mock()
        l.connection.drain_events.side_effect = WorkerShutdown()

        with self.assertRaises(WorkerShutdown):
            l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.register_callback.called)
        return l.task_consumer.register_callback.call_args[0][0]

    def test_receieve_message(self):
        """A valid task message becomes a ``Request`` in the buffer."""
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.event_dispatcher = mock_event_dispatcher()
        m = create_message(Mock(), task=self.foo_task.name,
                           args=[2, 4, 8], kwargs={})
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)

        in_bucket = self.buffer.get_nowait()
        self.assertIsInstance(in_bucket, Request)
        self.assertEqual(in_bucket.name, self.foo_task.name)
        self.assertEqual(in_bucket.execute(), 2 * 4 * 8)
        self.assertTrue(self.timer.empty())

    def test_start_channel_error(self):
        """A channel error raised by the loop propagates out of start()."""

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.buffer.put, timer=self.timer,
                         send_events=False, pool=BasePool(), app=self.app)
        l.channel_errors = (KeyError, )
        with self.assertRaises(KeyError):
            l.start()
        l.timer.stop()

    def test_start_connection_error(self):
        """After a connection error start() runs the loop again."""

        class MockConsumer(Consumer):
            iterations = 0

            def loop(self, *args, **kwargs):
                if not self.iterations:
                    self.iterations = 1
                    raise KeyError('foo')
                raise SyntaxError('bar')

        l = MockConsumer(self.buffer.put, timer=self.timer,
                         send_events=False, pool=BasePool(), app=self.app)

        l.connection_errors = (KeyError, )
        # The second loop() call raises SyntaxError, proving a retry.
        self.assertRaises(SyntaxError, l.start)
        l.timer.stop()

    def test_loop_ignores_socket_timeout(self):
        """``socket.timeout`` from ``drain_events`` does not escape loop()."""

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.timeout(10)

        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.connection = Connection()
        l.task_consumer = Mock()
        l.connection.obj = l
        l.qos = QoS(l.task_consumer.qos, 10)
        l.loop(*l.loop_args())

    def test_loop_when_socket_error(self):
        """``socket.error`` escapes loop() in RUN but not in CLOSE state."""

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None
                raise socket.error('foo')

        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        c = l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)
        with self.assertRaises(socket.error):
            l.loop(*l.loop_args())

        l.blueprint.state = CLOSE
        l.connection = c
        l.loop(*l.loop_args())

    def test_loop(self):
        """loop() consumes tasks and applies QoS prefetch-count updates."""

        class Connection(self.app.connection().__class__):
            obj = None

            def drain_events(self, **kwargs):
                self.obj.connection = None

        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.connection = Connection()
        l.connection.obj = l
        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 10)

        l.loop(*l.loop_args())
        l.loop(*l.loop_args())
        self.assertTrue(l.task_consumer.consume.call_count)
        l.task_consumer.qos.assert_called_with(prefetch_count=10)
        self.assertEqual(l.qos.value, 10)
        l.qos.decrement_eventually()
        self.assertEqual(l.qos.value, 9)
        l.qos.update()
        self.assertEqual(l.qos.value, 9)
        l.task_consumer.qos.assert_called_with(prefetch_count=9)

    def test_ignore_errors(self):
        """ignore_errors swallows connection/channel errors only."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.connection_errors = (AttributeError, KeyError, )
        l.channel_errors = (SyntaxError, )
        ignore_errors(l, Mock(side_effect=AttributeError('foo')))
        ignore_errors(l, Mock(side_effect=KeyError('foo')))
        ignore_errors(l, Mock(side_effect=SyntaxError('foo')))
        with self.assertRaises(IndexError):
            ignore_errors(l, Mock(side_effect=IndexError('foo')))

    def test_apply_eta_task(self):
        """apply_eta_task reserves the task, lowers QoS and buffers it."""
        from celery.worker import state
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.qos = QoS(None, 10)

        task = object()
        qos = l.qos.value
        l.apply_eta_task(task)
        self.assertIn(task, state.reserved_requests)
        self.assertEqual(l.qos.value, qos - 1)
        self.assertIs(self.buffer.get_nowait(), task)

    def test_receieve_message_eta_isoformat(self):
        """A task with an ISO-format ETA is put on the timer; QoS rises."""
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        m = create_message(
            Mock(), task=self.foo_task.name,
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
            args=[2, 4, 8], kwargs={},
        )

        l.task_consumer = Mock()
        l.qos = QoS(l.task_consumer.qos, 1)
        current_pcount = l.qos.value
        l.event_dispatcher = mock_event_dispatcher()
        l.enabled = False
        l.update_strategies()
        callback = self._get_on_message(l)
        callback(m.decode(), m)
        # Stop the timer thread before inspecting its queue.
        l.timer.stop()
        l.timer.join(1)

        # Each queue entry's third element wraps the scheduled request.
        found = any(
            entry[2].args[0].name == self.foo_task.name
            for entry in self.timer.queue
        )
        self.assertTrue(found)
        self.assertGreater(l.qos.value, current_pcount)

    def test_pidbox_callback(self):
        """The pidbox forwards messages to its node despite handler errors."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        con = find_step(l, consumer.Control).box
        con.node = Mock()
        con.reset = Mock()

        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = KeyError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')

        con.node = Mock()
        con.node.handle_message.side_effect = ValueError('foo')
        con.on_message('foo', 'bar')
        con.node.handle_message.assert_called_with('foo', 'bar')
        self.assertTrue(con.reset.called)

    def test_revoke(self):
        """Messages for revoked task ids are not put in the buffer."""
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        id = uuid()
        t = create_message(backend, task=self.foo_task.name, args=[2, 4, 8],
                           kwargs={}, id=id)
        from celery.worker.state import revoked
        revoked.add(id)

        callback = self._get_on_message(l)
        callback(t.decode(), t)
        self.assertTrue(self.buffer.empty())

    def test_receieve_message_not_registered(self):
        """Unregistered tasks are neither buffered nor scheduled."""
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        l.steps.pop()
        backend = Mock()
        m = create_message(backend, task='x.X.31x', args=[2, 4, 8], kwargs={})

        l.event_dispatcher = mock_event_dispatcher()
        callback = self._get_on_message(l)
        self.assertFalse(callback(m.decode(), m))
        with self.assertRaises(Empty):
            self.buffer.get_nowait()
        self.assertTrue(self.timer.empty())

    @patch('celery.worker.consumer.warn')
    @patch('celery.worker.consumer.logger')
    def test_receieve_message_ack_raises(self, logger, warn):
        """A connection error while rejecting is logged, not raised."""
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.blueprint.state = RUN
        backend = Mock()
        m = create_message(backend, args=[2, 4, 8], kwargs={})

        l.event_dispatcher = mock_event_dispatcher()
        l.connection_errors = (socket.error, )
        m.reject = Mock()
        m.reject.side_effect = socket.error('foo')
        callback = self._get_on_message(l)
        self.assertFalse(callback(m.decode(), m))
        self.assertTrue(warn.call_count)
        with self.assertRaises(Empty):
            self.buffer.get_nowait()
        self.assertTrue(self.timer.empty())
        m.reject.assert_called_with(requeue=False)
        self.assertTrue(logger.critical.call_count)

    def test_receive_message_eta(self):
        """A task with a future ETA is held on the timer, not buffered.

        Set the ``C_DEBUG_TEST`` environment variable to print progress
        markers to stderr.
        """
        import sys
        from functools import partial
        if os.environ.get('C_DEBUG_TEST'):
            pp = partial(print, file=sys.__stderr__)
        else:
            def pp(*args, **kwargs):
                pass
        pp('TEST RECEIVE MESSAGE ETA')
        pp('+CREATE MYKOMBUCONSUMER')
        l = _MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        pp('-CREATE MYKOMBUCONSUMER')
        l.steps.pop()
        l.event_dispatcher = mock_event_dispatcher()
        backend = Mock()
        pp('+ CREATE MESSAGE')
        m = create_message(
            backend, task=self.foo_task.name,
            args=[2, 4, 8], kwargs={},
            eta=(datetime.now() + timedelta(days=1)).isoformat(),
        )
        pp('- CREATE MESSAGE')

        try:
            pp('+ BLUEPRINT START 1')
            l.blueprint.start(l)
            pp('- BLUEPRINT START 1')
            p = l.app.conf.BROKER_CONNECTION_RETRY
            l.app.conf.BROKER_CONNECTION_RETRY = False
            pp('+ BLUEPRINT START 2')
            l.blueprint.start(l)
            pp('- BLUEPRINT START 2')
            l.app.conf.BROKER_CONNECTION_RETRY = p
            pp('+ BLUEPRINT RESTART')
            l.blueprint.restart(l)
            pp('- BLUEPRINT RESTART')
            l.event_dispatcher = mock_event_dispatcher()
            pp('+ GET ON MESSAGE')
            callback = self._get_on_message(l)
            pp('- GET ON MESSAGE')
            pp('+ CALLBACK')
            callback(m.decode(), m)
            pp('- CALLBACK')
        finally:
            pp('+ STOP TIMER')
            l.timer.stop()
            pp('- STOP TIMER')
            try:
                pp('+ JOIN TIMER')
                l.timer.join()
                pp('- JOIN TIMER')
            except RuntimeError:
                pass

        in_hold = l.timer.queue[0]
        self.assertEqual(len(in_hold), 3)
        eta, priority, entry = in_hold
        task = entry.args[0]
        self.assertIsInstance(task, Request)
        self.assertEqual(task.name, self.foo_task.name)
        self.assertEqual(task.execute(), 2 * 4 * 8)
        with self.assertRaises(Empty):
            self.buffer.get_nowait()

    def test_reset_pidbox_node(self):
        """Pidbox reset closes the node channel despite connection errors."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        con = find_step(l, consumer.Control).box
        con.node = Mock()
        chan = con.node.channel = Mock()
        l.connection = Mock()
        chan.close.side_effect = socket.error('foo')
        l.connection_errors = (socket.error, )
        con.reset()
        chan.close.assert_called_with()

    def test_reset_pidbox_node_green(self):
        """With a green pool, Control uses gPidbox and spawns its loop."""
        from celery.worker.pidbox import gPidbox
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                            app=self.app)
        con = find_step(l, consumer.Control)
        self.assertIsInstance(con.box, gPidbox)
        con.start(l)
        l.pool.spawn_n.assert_called_with(
            con.box.loop, l,
        )

    def test__green_pidbox_node(self):
        """The green pidbox loop consumes, survives a timeout and cleans up."""
        pool = Mock()
        pool.is_green = True
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, pool=pool,
                            app=self.app)
        l.node = Mock()
        controller = find_step(l, consumer.Control)

        class BConsumer(Mock):

            def __enter__(self):
                self.consume()
                return self

            def __exit__(self, *exc_info):
                self.cancel()

        controller.box.node.listen = BConsumer()
        connections = []

        class Connection(object):
            calls = 0

            def __init__(self, obj):
                connections.append(self)
                self.obj = obj
                self.default_channel = self.channel()
                self.closed = False

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

            def channel(self):
                return Mock()

            def as_uri(self):
                return 'dummy://'

            def drain_events(self, **kwargs):
                # First call times out; the second ends the pidbox loop.
                if not self.calls:
                    self.calls += 1
                    raise socket.timeout()
                self.obj.connection = None
                controller.box._node_shutdown.set()

            def close(self):
                self.closed = True

        l.connection = Mock()
        l.connect = lambda: Connection(obj=l)
        controller.box.loop(l)

        self.assertTrue(controller.box.node.listen.called)
        self.assertTrue(controller.box.consumer)
        controller.box.consumer.consume.assert_called_with()

        self.assertIsNone(l.connection)
        self.assertTrue(connections[0].closed)

    @patch('kombu.connection.Connection._establish_connection')
    @patch('kombu.utils.sleep')
    def test_connect_errback(self, sleep, connect):
        """connect() recovers when the first attempt raises a conn. error."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        from kombu.transport.memory import Transport

        def effect():
            if connect.call_count > 1:
                return
            raise ChannelError('error')
        connect.side_effect = effect

        # Scope the override of the memory transport's connection errors
        # to this test. Assigning the class attribute directly leaked it
        # into every test that ran afterwards.
        with patch.object(Transport, 'connection_errors', (ChannelError, )):
            l.connect()
        connect.assert_called_with()

    def test_stop_pidbox_node(self):
        """Control.stop() completes when ``_node_stopped`` is already set."""
        l = MyKombuConsumer(self.buffer.put, timer=self.timer, app=self.app)
        cont = find_step(l, consumer.Control)
        cont._node_stopped = Event()
        cont._node_shutdown = Event()
        cont._node_stopped.set()
        cont.stop(l)

    def test_start__loop(self):
        """start() keeps looping until loop() raises, updating QoS."""

        class _QoS(object):
            prev = 3
            value = 4

            def update(self):
                self.prev = self.value

        class _Consumer(MyKombuConsumer):
            iterations = 0

            def reset_connection(self):
                if self.iterations >= 1:
                    raise KeyError('foo')

        init_callback = Mock()
        l = _Consumer(self.buffer.put, timer=self.timer,
                      init_callback=init_callback, app=self.app)
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.qos = _QoS()
        l.connection = Connection()
        l.iterations = 0

        def raises_KeyError(*args, **kwargs):
            l.iterations += 1
            if l.qos.prev != l.qos.value:
                l.qos.update()
            if l.iterations >= 2:
                raise KeyError('foo')

        l.loop = raises_KeyError
        with self.assertRaises(KeyError):
            l.start()
        self.assertEqual(l.iterations, 2)
        self.assertEqual(l.qos.prev, l.qos.value)

        init_callback.reset_mock()
        l = _Consumer(self.buffer.put, timer=self.timer, app=self.app,
                      send_events=False, init_callback=init_callback)
        l.qos = _QoS()
        l.task_consumer = Mock()
        l.broadcast_consumer = Mock()
        l.connection = Connection()
        l.loop = Mock(side_effect=socket.error('foo'))
        with self.assertRaises(socket.error):
            l.start()
        self.assertTrue(l.loop.call_count)

    def test_reset_connection_with_no_node(self):
        """The blueprint starts on a consumer created without a pool."""
        l = Consumer(self.buffer.put, timer=self.timer, app=self.app)
        l.steps.pop()
        self.assertEqual(None, l.pool)
        l.blueprint.start(l)


class test_WorkController(AppCase):

    def setup(self):
        """Create the shared worker, mock the module loggers, add a task.

        The worker is built before the loggers are replaced; the originals
        are stashed so ``teardown`` can put them back.
        """
        self.worker = self.create_worker()
        from celery import worker as worker_module
        self._logger = worker_module.logger
        self._comp_logger = components.logger
        self.logger = worker_module.logger = Mock()
        self.comp_logger = components.logger = Mock()

        @self.app.task(shared=False)
        def foo_task(x, y, z):
            return x * y * z

        self.foo_task = foo_task

    def teardown(self):
        """Restore the module loggers that ``setup`` replaced with mocks."""
        from celery import worker as worker_module
        components.logger = self._comp_logger
        worker_module.logger = self._logger

    def create_worker(self, **kw):
        """Build a WorkController for ``self.app``.

        ``concurrency`` defaults to 1 and ``loglevel`` to 0. Either can be
        overridden through ``kw``; every other keyword is passed straight
        to ``WorkController``. The blueprint's ``shutdown_complete`` event
        is set before the worker is returned.
        """
        # setdefault instead of fixed keywords, so callers can override
        # these without a "multiple values for keyword" TypeError.
        kw.setdefault('concurrency', 1)
        kw.setdefault('loglevel', 0)
        worker = self.app.WorkController(**kw)
        worker.blueprint.shutdown_complete.set()
        return worker

    def test_on_consumer_ready(self):
        """on_consumer_ready() can be called with an arbitrary consumer."""
        consumer = Mock()
        self.worker.on_consumer_ready(consumer)

    def test_setup_queues_worker_direct(self):
        """With CELERY_WORKER_DIRECT on, setup_queues() adds the worker's
        direct queue via ``select_add``."""
        app = self.app
        app.conf.CELERY_WORKER_DIRECT = True
        app.amqp.__dict__['queues'] = Mock()
        self.worker.setup_queues({})
        expected_queue = worker_direct(self.worker.hostname)
        self.app.amqp.queues.select_add.assert_called_with(expected_queue)

    def test_send_worker_shutdown(self):
        """_send_worker_shutdown() sends worker_shutdown with the worker as
        sender."""
        target = 'celery.signals.worker_shutdown'
        with patch(target) as shutdown_signal:
            self.worker._send_worker_shutdown()
            shutdown_signal.send.assert_called_with(sender=self.worker)

    def test_process_shutdown_on_worker_shutdown(self):
        """Worker._do_exit() (with process_destructor as on_exit) sends
        worker_process_shutdown and calls os._exit with the exit code.

        Currently skipped as unstable: everything after the SkipTest is
        unreachable until the skip is removed.
        """
        raise SkipTest('unstable test')
        from celery.concurrency.prefork import process_destructor
        from celery.concurrency.asynpool import Worker
        with patch('celery.signals.worker_process_shutdown') as ws:
            # patch.object restores Worker._make_shortcuts on exit; a plain
            # class assignment would leak the Mock into every later test.
            with patch.object(Worker, '_make_shortcuts', Mock()):
                with patch('os._exit') as _exit:
                    worker = Worker(None, None, on_exit=process_destructor)
                    worker._do_exit(22, 3.1415926)
                    ws.send.assert_called_with(
                        sender=None, pid=22, exitcode=3.1415926,
                    )
                    _exit.assert_called_with(3.1415926)

    def test_process_task_revoked_release_semaphore(self):
        """A TaskRevokedError from the pool triggers ``_quick_release()``.

        The call must also succeed once ``_quick_release`` has been removed
        from the worker instance.
        """
        w = self.worker
        w._quick_release = Mock()
        request = Mock()
        request.execute_using_pool.side_effect = TaskRevokedError
        w._process_task(request)
        w._quick_release.assert_called_with()

        del w._quick_release
        w._process_task(request)

    def test_shutdown_no_blueprint(self):
        """_shutdown() does not raise when the worker has no blueprint."""
        w = self.worker
        w.blueprint = None
        w._shutdown()

    @patch('celery.platforms.create_pidlock')
    def test_use_pidfile(self, create_pidlock):
        """A worker given ``pidfile`` creates a pidlock on start() and
        releases it on stop()."""
        create_pidlock.return_value = Mock()
        w = self.create_worker(pidfile='pidfilelockfilepid')
        w.steps = []  # no steps: only the pidlock handling is exercised
        w.start()
        self.assertTrue(create_pidlock.called)
        w.stop()
        self.assertTrue(w.pidlock.release.called)

    @patch('celery.platforms.signals')
    @patch('celery.platforms.set_mp_process_title')
    def test_process_initializer(self, set_mp_process_title, _signals):
        """process_initializer() prepares a pool child process.

        Checks that it:
        - resets and ignores the worker signal sets;
        - initialises the loader and fires ``worker_process_init``;
        - makes the app current and sets the process title.

        When FORKED_BY_MULTIPROCESSING is set, it instead calls
        ``setup_worker_optimizations``.
        """
        with restore_logging():
            from celery import signals
            from celery._state import _tls
            from celery.concurrency.prefork import (
                process_initializer, WORKER_SIGRESET, WORKER_SIGIGNORE,
            )

            def on_worker_process_init(**kwargs):
                on_worker_process_init.called = True
            on_worker_process_init.called = False
            signals.worker_process_init.connect(on_worker_process_init)

            def Loader(*args, **kwargs):
                loader = Mock(*args, **kwargs)
                loader.conf = {}
                loader.override_backends = {}
                return loader

            try:
                with self.Celery(loader=Loader) as app:
                    app.conf = AttributeDict(DEFAULTS)
                    process_initializer(app, 'awesome.worker.com')
                    _signals.ignore.assert_any_call(*WORKER_SIGIGNORE)
                    _signals.reset.assert_any_call(*WORKER_SIGRESET)
                    self.assertTrue(app.loader.init_worker.call_count)
                    self.assertTrue(on_worker_process_init.called)
                    self.assertIs(_tls.current_app, app)
                    set_mp_process_title.assert_called_with(
                        'celeryd', hostname='awesome.worker.com',
                    )

                    setup_opts = 'celery.app.trace.setup_worker_optimizations'
                    forked = {'FORKED_BY_MULTIPROCESSING': '1'}
                    with patch(setup_opts) as S:
                        # patch.dict restores os.environ exactly, including
                        # a FORKED_BY_MULTIPROCESSING value that was already
                        # set before the test (a bare pop() would drop it).
                        with patch.dict(os.environ, forked):
                            process_initializer(app, 'luke.worker.com')
                            S.assert_called_with(app)
            finally:
                # worker_process_init is a module-level signal; do not leave
                # this test's receiver connected for later tests.
                signals.worker_process_init.disconnect(on_worker_process_init)

    def test_attrs(self):
        """A freshly created worker has a Timer, a pool, a consumer and
        a non-empty step list."""
        w = self.worker
        self.assertIsNotNone(w.timer)
        self.assertIsInstance(w.timer, Timer)
        for name in ('pool', 'consumer'):
            self.assertIsNotNone(getattr(w, name))
        self.assertTrue(w.steps)

    def test_with_embedded_beat(self):
        """beat=True gives the worker a beat object registered as a step."""
        w = self.app.WorkController(concurrency=1, loglevel=0, beat=True)
        self.assertTrue(w.beat)
        step_objects = [step.obj for step in w.steps]
        self.assertIn(w.beat, step_objects)

    def test_with_autoscaler(self):
        """Passing ``autoscale`` gives the worker an autoscaler."""
        options = dict(
            autoscale=[10, 3],
            send_events=False,
            timer_cls='celery.utils.timer2.Timer',
        )
        w = self.create_worker(**options)
        self.assertTrue(w.autoscaler)

    def test_dont_stop_or_terminate(self):
        """stop()/terminate() must not put the blueprint into CLOSE.

        Covered cases: a worker that was never started, and calls made
        with ``in_sighandler=True`` while the pool is not signal safe.
        """
        w = self.app.WorkController(concurrency=1, loglevel=0)
        for action in (w.stop, w.terminate):
            action()
            self.assertNotEqual(w.blueprint.state, CLOSE)

        saved_signal_safe = w.pool.signal_safe
        w.pool.signal_safe = False
        try:
            w.blueprint.state = RUN
            for action in (w.stop, w.terminate):
                action(in_sighandler=True)
                self.assertNotEqual(w.blueprint.state, CLOSE)
        finally:
            w.pool.signal_safe = saved_signal_safe

    def test_on_timer_error(self):
        """The Timer step logs timer errors at error level, and the rendered
        message names the exception type."""
        w = self.app.WorkController(concurrency=1, loglevel=0)

        try:
            raise KeyError('foo')
        except KeyError as exc:
            components.Timer(w).on_timer_error(exc)
            fmt, fmt_args = self.comp_logger.error.call_args[0]
            rendered = fmt % fmt_args
            self.assertIn('KeyError', rendered)

    def test_on_timer_tick(self):
        """Timer ticks are logged at debug level together with the delay."""
        w = self.app.WorkController(concurrency=1, loglevel=10)

        components.Timer(w).on_timer_tick(30.0)
        logged = self.comp_logger.debug.call_args[0]
        fmt = logged[0]
        delay = logged[1]
        self.assertEqual(30.0, delay)
        self.assertIn('Next eta %s secs', fmt)

    def test_process_task(self):
        """_process_task() submits the request to the pool exactly once."""
        w = self.worker
        w.pool = Mock()
        message = create_message(
            Mock(), task=self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message.decode(), message=message, app=self.app)
        w._process_task(request)
        self.assertEqual(w.pool.apply_async.call_count, 1)
        w.pool.stop()

    def test_process_task_raise_base(self):
        """A KeyboardInterrupt raised by the pool propagates out of
        _process_task()."""
        w = self.worker
        w.pool = Mock()
        w.pool.apply_async.side_effect = KeyboardInterrupt('Ctrl+C')
        message = create_message(
            Mock(), task=self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message.decode(), message=message, app=self.app)
        w.steps = []
        w.blueprint.state = RUN
        with self.assertRaises(KeyboardInterrupt):
            w._process_task(request)

    def test_process_task_raise_WorkerTerminate(self):
        """A WorkerTerminate raised by the pool surfaces from _process_task()
        as something ``assertRaises(SystemExit)`` catches."""
        w = self.worker
        w.pool = Mock()
        w.pool.apply_async.side_effect = WorkerTerminate()
        message = create_message(
            Mock(), task=self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message.decode(), message=message, app=self.app)
        w.steps = []
        w.blueprint.state = RUN
        with self.assertRaises(SystemExit):
            w._process_task(request)

    def test_process_task_raise_regular(self):
        """An ordinary exception from the pool does not escape
        _process_task()."""
        w = self.worker
        w.pool = Mock()
        w.pool.apply_async.side_effect = KeyError('some exception')
        message = create_message(
            Mock(), task=self.foo_task.name, args=[4, 8, 10], kwargs={},
        )
        request = Request(message.decode(), message=message, app=self.app)
        w._process_task(request)
        w.pool.stop()

    def test_start_catches_base_exceptions(self):
        """start() does not propagate WorkerTerminate/WorkerShutdown.

        On WorkerTerminate the step is terminated; on WorkerShutdown
        (with ``terminate`` unset) the step is stopped.
        """
        terminating = self.create_worker()
        terminating.blueprint.state = RUN
        term_step = MockStep()
        term_step.start.side_effect = WorkerTerminate()
        terminating.steps = [term_step]
        terminating.start()
        term_step.start.assert_called_with(terminating)
        self.assertTrue(term_step.terminate.call_count)

        shutting_down = self.create_worker()
        shutting_down.blueprint.state = RUN
        shut_step = MockStep()
        shut_step.start.side_effect = WorkerShutdown()
        shut_step.terminate = None
        shutting_down.steps = [shut_step]
        shutting_down.start()
        self.assertTrue(shut_step.stop.call_count)

    def test_state_db(self):
        """Passing ``state_db`` gives the worker a truthy ``_persistence``."""
        from celery.worker import state
        # Swap in a plain Mock for state.Persistent; patch.object puts the
        # real class back on exit, exactly like a try/finally restore.
        with patch.object(state, 'Persistent', Mock()):
            w = self.create_worker(state_db='statefilename')
            self.assertTrue(w._persistence)

    def test_process_task_sem(self):
        """_process_task_sem() calls ``_quick_acquire`` with _process_task
        and the request."""
        w = self.worker
        w._quick_acquire = Mock()
        request = Mock()
        w._process_task_sem(request)
        w._quick_acquire.assert_called_with(w._process_task, request)

    def test_signal_consumer_close(self):
        """signal_consumer_close() closes the consumer.

        An AttributeError raised by ``close()`` must not escape.
        """
        w = self.worker
        w.consumer = Mock()

        w.signal_consumer_close()
        w.consumer.close.assert_called_with()

        w.consumer.close.side_effect = AttributeError()
        w.signal_consumer_close()

    def test_start__stop(self):
        """start() starts every step; stop() closes and stops every step.

        stop() is then exercised again with no pool, and with a ``None``
        entry in ``steps``.
        """
        w = self.worker
        bp = w.blueprint
        bp.shutdown_complete.set()
        w.steps = [MockStep(StartStopStep(self)) for _ in range(4)]
        bp.state = RUN
        bp.started = 4
        for step in w.steps:
            step.start, step.close, step.stop = Mock(), Mock(), Mock()

        w.start()
        for step in w.steps:
            self.assertTrue(step.start.call_count)
        w.consumer = Mock()
        w.stop()
        for step in w.steps:
            self.assertTrue(step.close.call_count)
            self.assertTrue(step.stop.call_count)

        # stop() must not try to close a pool that is not there.
        w.start()
        w.pool = None
        w.stop()

        # A None step must be skipped rather than stopped.
        w.steps[-1] = None
        w.start()
        w.stop()

    def test_step_raises(self):
        """A step failing with a regular exception makes start() call stop()."""
        w = self.worker
        broken = Mock()
        broken.start.side_effect = TypeError()
        w.steps = [broken]
        w.stop = Mock()
        w.start()
        w.stop.assert_called_with()

    def test_state(self):
        """The worker exposes a truthy ``state`` attribute."""
        worker_state = self.worker.state
        self.assertTrue(worker_state)

    def test_start__terminate(self):
        """start() starts every step; terminate() then terminates each one."""
        worker = self.worker
        worker.blueprint.shutdown_complete.set()
        worker.blueprint.started = 5
        worker.blueprint.state = RUN
        worker.steps = [MockStep() for _ in range(5)]
        worker.start()
        for step in worker.steps:
            self.assertTrue(step.start.call_count)
        # assertEqual, not assertTrue: assertTrue's second argument is only
        # the failure message, so the step count would never be compared.
        self.assertEqual(worker.blueprint.started, len(worker.steps))
        self.assertEqual(worker.blueprint.state, RUN)
        worker.terminate()
        for step in worker.steps:
            self.assertTrue(step.terminate.call_count)

    def test_Queues_pool_no_sem(self):
        """With a pool that does not use a semaphore, Queues.create() binds
        ``process_task`` directly to ``_process_task``."""
        w = Mock()
        w.pool_cls.uses_semaphore = False
        queues_step = components.Queues(w)
        queues_step.create(w)
        self.assertIs(w.process_task, w._process_task)

    def test_Hub_crate(self):
        """components.Hub.create() runs against a mocked worker.

        The method name keeps its historical "crate" spelling.
        NOTE(review): ``w`` is a Mock, so unless create() assigns a falsy
        value to ``w.timer.max_interval``, that attribute is an auto-created
        (truthy) Mock and the assertion passes trivially. TODO tighten.
        """
        w = Mock()
        hub_step = components.Hub(w)
        hub_step.create(w)
        self.assertTrue(w.timer.max_interval)

    def test_Pool_crate_threaded(self):
        """components.Pool.create() runs with ``use_eventloop`` disabled.

        The method name keeps its historical "crate" spelling.
        NOTE(review): this test makes no assertions; it only checks that
        create() does not raise. TODO assert on the created pool.
        """
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.pool_cls = Mock()
        w.use_eventloop = False
        components.Pool(w).create(w)

    def test_Pool_create(self):
        from kombu.async.semaphore import LaxBoundedSemaphore
        w = Mock()
        w._conninfo.connection_errors = w._conninfo.channel_errors = ()
        w.hub = Mock()

        PoolImp = Mock()
        poolimp = PoolImp.return_value = Mock()
        poolimp._pool = [Mock(), Mock()]
        poolimp._cache = {}
        poolimp._fileno_to_inq = {}
        poolimp._fileno_to_outq = {}

        from celery.concurrency.prefork import TaskPool as _TaskPool

        class MockTaskPool(_TaskPool):
            Pool = PoolImp

            @property
            def timers(self):
                return {Mock(): 30}

        w.pool_cls = MockTaskPool
        w.use_eventloop = True
        w.consumer.restart_count = -1
        pool = components.Pool(w)
        pool.create(w)
        pool.register_with_event_loop(w, w.hub)
        self.assertIsInstance(w.semaphore, LaxBoundedSemaphore)
        P = w.pool
        P.start()