Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

hemamaps / celery   python

Repository URL to install this package:

Version: 3.1.23 

/ tests / worker / test_strategy.py

from __future__ import absolute_import

from collections import defaultdict
from contextlib import contextmanager

from kombu.utils.limits import TokenBucket

from celery.worker import state
from celery.utils.timeutils import rate

from celery.tests.case import AppCase, Mock, patch, body_from_sig


class test_default_strategy(AppCase):
    """Tests for the default worker task strategy.

    Each test builds a strategy via the task type's ``start_strategy``
    using a mocked consumer and message, feeds it a single message, and
    then inspects which hook (reserved / rate limited / timer) received
    the resulting request.
    """

    def setup(self):
        """Register a non-shared ``add`` task on the test app."""
        def add(x, y):
            return x + y

        # Equivalent to decorating ``add`` with ``@self.app.task(shared=False)``.
        self.add = self.app.task(shared=False)(add)

    class Context(object):
        """Bundle a strategy with the mocks it was built from.

        Attributes (copied verbatim from the constructor arguments):

        * ``sig`` -- the signature being executed.
        * ``s`` -- the strategy callable returned by ``start_strategy``.
        * ``reserved`` -- mock passed to the strategy as ``task_reserved``.
        * ``consumer`` -- mock consumer handed to the strategy.
        * ``message`` -- mock message whose ``ack``/``reject`` are passed on.
        * ``body`` -- message body built from ``sig``.
        """

        def __init__(self, sig, s, reserved, consumer, message, body):
            (self.sig, self.s, self.reserved,
             self.consumer, self.message, self.body) = (
                sig, s, reserved, consumer, message, body)

        def __call__(self, **kwargs):
            """Invoke the strategy once with this context's message and body.

            Keyword arguments are forwarded to the strategy unchanged and
            the strategy's return value is passed back.
            """
            msg = self.message
            return self.s(msg, self.body, msg.ack, msg.reject, [], **kwargs)

        def was_reserved(self):
            """Return whether the ``task_reserved`` hook was called."""
            return self.reserved.called

        def was_rate_limited(self):
            """Return whether ``consumer._limit_task`` was called.

            Asserts first that the request was not also reserved.
            """
            assert not self.was_reserved()
            limit_task = self.consumer._limit_task
            return limit_task.called

        def was_scheduled(self):
            """Return whether ``consumer.timer.call_at`` was called.

            Asserts first that the request was neither reserved nor rate
            limited.
            """
            assert not (self.was_reserved() or self.was_rate_limited())
            timer = self.consumer.timer
            return timer.call_at.called

        def event_sent(self):
            """Return the last ``call_args`` of ``event_dispatcher.send``.

            This is ``None`` when ``send`` was never called.
            """
            dispatcher = self.consumer.event_dispatcher
            return dispatcher.send.call_args

        def get_request(self):
            """Return the request object the strategy produced.

            The reserved hook, ``consumer._limit_task`` and
            ``consumer.timer.call_at`` are checked in that order. The first
            positional argument of the first one that was called is returned.

            :raises ValueError: if none of them were called.
            """
            # Targets are wrapped in lambdas so each mock is only looked up
            # once its predicate has matched.
            routes = (
                (self.was_reserved, lambda: self.reserved),
                (self.was_rate_limited, lambda: self.consumer._limit_task),
                (self.was_scheduled, lambda: self.consumer.timer.call_at),
            )
            for was_called, target in routes:
                if was_called():
                    return target().call_args[0][0]
            raise ValueError('request not handled')

    @contextmanager
    def _context(self, sig,
                 rate_limits=True, events=True, utc=True, limit=None):
        """Build a strategy for ``sig`` against a mocked consumer.

        :param sig: signature whose task type provides the strategy.
        :keyword rate_limits: when false, the consumer reports rate limits
            as disabled (``disable_rate_limits`` is the negation).
        :keyword events: value assigned to ``event_dispatcher.enabled``.
        :keyword utc: forwarded to ``body_from_sig``.
        :keyword limit: optional rate string (e.g. ``'1/m'``). When given,
            a ``TokenBucket`` with capacity 1 is registered for the task.
            Every other task maps to ``None``.

        Yields a ``Context`` wrapping the strategy and its mocks.
        """
        self.assertTrue(sig.type.Strategy)

        task_reserved = Mock()
        consumer = Mock()
        buckets = consumer.task_buckets = defaultdict(lambda: None)
        if limit:
            buckets[sig.task] = TokenBucket(rate(limit), capacity=1)
        consumer.disable_rate_limits = not rate_limits
        consumer.event_dispatcher.enabled = events

        strategy = sig.type.start_strategy(
            self.app, consumer, task_reserved=task_reserved,
        )
        self.assertTrue(strategy)

        # Keyword arguments evaluate left to right, so the message mock is
        # still created before the body is built from ``sig``.
        yield self.Context(
            sig=sig,
            s=strategy,
            reserved=task_reserved,
            consumer=consumer,
            message=Mock(),
            body=body_from_sig(self.app, sig, utc=utc),
        )

    def _assert_eta_scheduled(self, **options):
        """Run ``add(2, 2)`` with ``countdown=10`` and check the timer path.

        ``options`` are forwarded to ``_context``.
        """
        sig = self.add.s(2, 2).set(countdown=10)
        with self._context(sig, **options) as ctx:
            ctx()
            self.assertTrue(ctx.was_scheduled())
            ctx.consumer.qos.increment_eventually.assert_called_with()

    def test_when_logging_disabled(self):
        # When the patched logger reports itself disabled, the strategy
        # must not call ``logger.info``.
        with patch('celery.worker.strategy.logger') as strategy_logger:
            strategy_logger.isEnabledFor.return_value = False
            with self._context(self.add.s(2, 2)) as ctx:
                ctx()
                self.assertFalse(strategy_logger.info.called)

    def test_task_strategy(self):
        # A plain call is reserved, handed to ``on_task_request`` and
        # produces an event.
        with self._context(self.add.s(2, 2)) as ctx:
            ctx()
            self.assertTrue(ctx.was_reserved())
            request = ctx.get_request()
            ctx.consumer.on_task_request.assert_called_with(request)
            self.assertTrue(ctx.event_sent())

    def test_when_events_disabled(self):
        # With the dispatcher disabled the task is still reserved, but no
        # event is sent.
        with self._context(self.add.s(2, 2), events=False) as ctx:
            ctx()
            self.assertTrue(ctx.was_reserved())
            self.assertFalse(ctx.event_sent())

    def test_eta_task(self):
        self._assert_eta_scheduled()

    def test_eta_task_utc_disabled(self):
        self._assert_eta_scheduled(utc=False)

    def test_when_rate_limited(self):
        # A '1/m' bucket with rate limits enabled routes to ``_limit_task``.
        sig = self.add.s(2, 2)
        with self._context(sig, rate_limits=True, limit='1/m') as ctx:
            ctx()
            self.assertTrue(ctx.was_rate_limited())

    def test_when_rate_limited__limits_disabled(self):
        # The same bucket is ignored when rate limits are disabled.
        sig = self.add.s(2, 2)
        with self._context(sig, rate_limits=False, limit='1/m') as ctx:
            ctx()
            self.assertTrue(ctx.was_reserved())

    def test_when_revoked(self):
        # A revoked task id reaches none of the reserved / rate limited /
        # scheduled hooks.
        sig = self.add.s(2, 2)
        sig.freeze()
        state.revoked.add(sig.id)
        try:
            with self._context(sig) as ctx:
                ctx()
                with self.assertRaises(ValueError):
                    ctx.get_request()
        finally:
            # Undo the mutation of the module-level ``state.revoked`` set.
            state.revoked.discard(sig.id)