Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

hemamaps / celery   python

Repository URL to install this package:

Version: 3.1.23 

/ tests / backends / test_redis.py

from __future__ import absolute_import

from datetime import timedelta

from pickle import loads, dumps

from celery import signature
from celery import states
from celery import group
from celery import uuid
from celery.datastructures import AttributeDict
from celery.exceptions import ImproperlyConfigured
from celery.utils.timeutils import timedelta_seconds

from celery.tests.case import (
    AppCase, Mock, MockCallbacks, SkipTest, depends_on_current_app, patch,
)


class Connection(object):
    """Minimal fake of a redis connection object.

    Only the ``connected`` flag is modelled: ``disconnect()`` flips it to
    False on the instance, while the class-level default stays True.
    """

    connected = True  # default seen by every freshly created instance

    def disconnect(self):
        """Mark this particular connection as closed."""
        self.connected = False


class Pipeline(object):
    """Fake redis pipeline that queues commands and runs them on execute().

    Any attribute that is not found normally is treated as a command name:
    it yields a callable that records the command (looked up on the client
    at the moment the callable is invoked) plus its arguments, and returns
    the pipeline so calls can be chained.
    """

    def __init__(self, client):
        self.client = client
        self.steps = []  # queued (bound client method, args, kwargs) triples

    def __getattr__(self, name):
        def queue(*args, **kwargs):
            # The client lookup happens here, not in __getattr__, so merely
            # touching an attribute never fails; calling an unknown one does.
            command = getattr(self.client, name)
            self.steps.append((command, args, kwargs))
            return self
        return queue

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Returning None lets any exception from the with-block propagate.
        return None

    def execute(self):
        """Run every queued command in order and return their results."""
        return [command(*args, **kwargs)
                for command, args, kwargs in self.steps]


class Redis(MockCallbacks):
    """In-memory stand-in for the redis client used by the result backend.

    Tests assert on calls made to this client (e.g. ``b.client.expire``
    in test_set_expires), presumably recorded via ``MockCallbacks``.  For
    that reason, methods that build on each other (``setex`` -> ``set`` /
    ``expire``) delegate through ``self`` rather than touching the dicts
    directly.
    """
    Connection = Connection
    Pipeline = Pipeline

    def __init__(self, host=None, port=None, db=None, password=None, **kw):
        # Connection parameters are only recorded; any extra keyword
        # arguments are accepted and ignored.
        self.host = host
        self.port = port
        self.db = db
        self.password = password
        self.keyspace = {}  # key -> stored value (plain values or lists)
        self.expiry = {}    # key -> most recent TTL passed to expire()
        self.connection = self.Connection()

    def get(self, key):
        """Return the value stored at *key*, or None if it is absent."""
        return self.keyspace.get(key)

    def setex(self, key, value, expires):
        """Store *value* at *key* and record *expires* as its TTL.

        Argument order is (key, value, expires).
        """
        self.set(key, value)
        self.expire(key, expires)

    def set(self, key, value):
        """Store *value* at *key*, replacing any previous value."""
        self.keyspace[key] = value

    def expire(self, key, expires):
        """Record *expires* as the TTL of *key* and return it."""
        self.expiry[key] = expires
        return expires

    def delete(self, key):
        """Remove *key* and its recorded TTL; return True if the key existed.

        Membership is tested explicitly so keys holding falsy values
        (``''``, ``[]``, ``0``) are still reported as deleted.
        """
        self.expiry.pop(key, None)
        if key not in self.keyspace:
            return False
        del self.keyspace[key]
        return True

    def pipeline(self):
        """Return a Pipeline that queues commands against this client."""
        return self.Pipeline(self)

    def _get_list(self, key):
        """Return the list stored at *key*, creating an empty one if absent."""
        return self.keyspace.setdefault(key, [])

    def rpush(self, key, value):
        """Append *value* to the list at *key* and return the new length.

        Real RPUSH replies with the list length, so pipeline results keep
        the same shape as with the real client.
        """
        items = self._get_list(key)
        items.append(value)
        return len(items)

    def lrange(self, key, start, stop):
        """Return elements *start* through *stop* of the list, inclusive.

        Follows Redis LRANGE semantics: *stop* is inclusive, negative
        indices count from the end (``-1`` is the last element), and
        out-of-range bounds are clamped instead of raising.  Reading a
        missing key returns ``[]`` without creating the key.
        """
        items = self.keyspace.get(key) or []
        size = len(items)
        if start < 0:
            start = max(start + size, 0)
        if stop < 0:
            stop += size
        if start > stop or start >= size:
            return []
        # Python slices exclude their end index; LRANGE includes it.
        return items[start:stop + 1]

    def llen(self, key):
        """Return the length of the list at *key* (0 when absent)."""
        return len(self.keyspace.get(key) or [])


class redis(object):
    """Fake ``redis`` module handed to the backend under test.

    Provides a client version tuple, the in-memory ``Redis`` class from this
    file, and two connection-related classes whose constructors accept and
    discard keyword arguments.
    """

    class UnixDomainSocketConnection(object):
        """No-op connection class (selected for ``socket://`` URLs)."""

        def __init__(self, **kwargs):
            """Accept and ignore keyword configuration."""

    class ConnectionPool(object):
        """No-op connection pool."""

        def __init__(self, **kwargs):
            """Accept and ignore keyword configuration."""

    Redis = Redis
    VERSION = (2, 4, 10)


class test_RedisBackend(AppCase):

    def get_backend(self):
        """Build a fresh RedisBackend subclass bound to the fake ``redis``.

        A new class is created on every call, so tests can mutate its
        attributes without affecting other tests.
        """
        from celery.backends.redis import RedisBackend as RealBackend

        class _RedisBackend(RealBackend):
            redis = redis  # the module-level fake, not the real client lib

        return _RedisBackend

    def setup(self):
        # Presumably invoked by AppCase before each test: give every test
        # its own backend class built by get_backend().
        backend_cls = self.get_backend()
        self.Backend = backend_cls

    @depends_on_current_app
    def test_reduce(self):
        """A real RedisBackend survives a pickle round-trip.

        Skipped when the backend cannot be created because redis support is
        unavailable.  Only the import and construction are inside the
        ``try``, so an ImportError raised while pickling is reported as a
        failure instead of being silently turned into a skip.
        """
        try:
            from celery.backends.redis import RedisBackend
            backend = RedisBackend(app=self.app, new_join=True)
        except (ImportError, ImproperlyConfigured):
            # When the backend's ``redis`` is None it raises
            # ImproperlyConfigured (see test_no_redis); presumably that is
            # how a missing client library surfaces here.
            raise SkipTest('redis not installed')
        self.assertTrue(loads(dumps(backend)))

    def test_no_redis(self):
        """Building the backend without a redis module is a config error."""
        backend_cls = self.Backend
        backend_cls.redis = None
        with self.assertRaises(ImproperlyConfigured):
            backend_cls(app=self.app, new_join=True)

    def test_url(self):
        """Password, host, port and db are parsed out of a redis:// URL."""
        url = 'redis://:bosco@vandelay.com:123//1'
        backend = self.Backend(url, app=self.app, new_join=True)
        self.assertTrue(backend.connparams)
        expectations = (
            ('host', 'vandelay.com'),
            ('db', 1),
            ('port', 123),
            ('password', 'bosco'),
        )
        for param, expected in expectations:
            self.assertEqual(backend.connparams[param], expected)

    def test_socket_url(self):
        """A socket:// URL selects the UNIX-socket connection class."""
        backend = self.Backend(
            'socket:///tmp/redis.sock?virtual_host=/3',
            app=self.app, new_join=True,
        )
        params = backend.connparams
        self.assertTrue(params)
        self.assertEqual(params['path'], '/tmp/redis.sock')
        self.assertIs(
            params['connection_class'], redis.UnixDomainSocketConnection,
        )
        # TCP-only settings must not appear for a socket connection.
        for tcp_only in ('host', 'port'):
            self.assertNotIn(tcp_only, params)
        # The db number comes from the ``virtual_host=/3`` query argument.
        self.assertEqual(params['db'], 3)

    def test_compat_propertie(self):
        """Legacy host/db/port/password attributes still return URL values.

        Each attribute access is expected to trigger a pending-deprecation
        warning, so every access is checked in its own context.
        """
        backend = self.Backend(
            'redis://:bosco@vandelay.com:123//1', app=self.app,
            new_join=True,
        )
        legacy_values = [
            ('host', 'vandelay.com'),
            ('db', 1),
            ('port', 123),
            ('password', 'bosco'),
        ]
        for attr, expected in legacy_values:
            with self.assertPendingDeprecation():
                self.assertEqual(getattr(backend, attr), expected)

    def test_conf_raises_KeyError(self):
        """The backend can be built from a conf holding only a few settings.

        Despite the name, nothing should raise: the app conf is replaced by
        a plain AttributeDict with just these keys, and the test passes only
        if constructing the backend does not fail (e.g. with a KeyError).
        """
        minimal_settings = {
            'CELERY_RESULT_SERIALIZER': 'json',
            'CELERY_MAX_CACHED_RESULTS': 1,
            'CELERY_ACCEPT_CONTENT': ['json'],
            'CELERY_TASK_RESULT_EXPIRES': None,
        }
        self.app.conf = AttributeDict(minimal_settings)
        self.Backend(app=self.app, new_join=True)

    def test_expires_defaults_to_config(self):
        """expires=None falls back to CELERY_TASK_RESULT_EXPIRES."""
        configured = 10
        self.app.conf.CELERY_TASK_RESULT_EXPIRES = configured
        backend = self.Backend(expires=None, app=self.app, new_join=True)
        self.assertEqual(backend.expires, configured)

    def test_expires_is_int(self):
        """An integer ``expires`` is stored unchanged."""
        seconds = 48
        backend = self.Backend(expires=seconds, app=self.app, new_join=True)
        self.assertEqual(backend.expires, seconds)

    def test_set_new_join_from_url_query(self):
        """``new_join=True`` in the URL query enables the new chord methods."""
        url = 'redis://?new_join=True;foobar=1'
        backend = self.Backend(url, app=self.app)
        self.assertEqual(
            backend.on_chord_part_return, backend._new_chord_return,
        )
        self.assertEqual(backend.apply_chord, backend._new_chord_apply)

    def test_default_is_old_join(self):
        """Without new_join the backend keeps its original chord methods."""
        backend = self.Backend(app=self.app)
        self.assertNotEqual(
            backend.on_chord_part_return, backend._new_chord_return,
        )
        self.assertNotEqual(backend.apply_chord, backend._new_chord_apply)

    def test_expires_is_None(self):
        """expires=None yields the configured default, in seconds."""
        backend = self.Backend(expires=None, app=self.app, new_join=True)
        actual = backend.expires
        expected = timedelta_seconds(
            self.app.conf.CELERY_TASK_RESULT_EXPIRES,
        )
        self.assertEqual(actual, expected)

    def test_expires_is_timedelta(self):
        """A timedelta ``expires`` is converted to a number of seconds."""
        # new_join=True for consistency with every other test in this class.
        b = self.Backend(
            expires=timedelta(minutes=1), app=self.app, new_join=True,
        )
        self.assertEqual(b.expires, 60)

    def test_apply_chord(self):
        """Smoke test: apply_chord accepts an empty group without raising."""
        backend = self.Backend(app=self.app, new_join=True)
        empty_group = group(app=self.app)
        results = [self.app.AsyncResult(task_id) for task_id in [1, 2, 3]]
        backend.apply_chord(
            empty_group, (), 'group_id', {}, result=results,
        )

    def test_mget(self):
        """mget forwards the key list to the client unchanged."""
        backend = self.Backend(app=self.app, new_join=True)
        fetched = backend.mget(['a', 'b', 'c'])
        self.assertTrue(fetched)
        backend.client.mget.assert_called_with(['a', 'b', 'c'])

    def test_set_no_expire(self):
        """Smoke test: set() works with expiry disabled (no assertions)."""
        backend = self.Backend(app=self.app, new_join=True)
        backend.expires = None
        backend.set('foo', 'bar')

    @patch('celery.result.GroupResult.restore')
    def test_on_chord_part_return(self, restore):
        """Feed every part of a chord through on_chord_part_return.

        Each part must push to the join list; afterwards the list must have
        been read, deleted and given a 86400-second expiry.
        """
        backend = self.Backend(app=self.app, new_join=True)
        chord_size = 10

        def make_task():
            task_id = uuid()
            task = Mock(name='task-{0}'.format(task_id))
            # Statement order is kept as-is: the name and registry entry are
            # set up before the chord signature is built from the task.
            task.name = 'foobarbaz'
            self.app.tasks['foobarbaz'] = task
            task.request.chord = signature(task)
            task.request.id = task_id
            task.request.chord['chord_size'] = chord_size
            task.request.group = 'group_id'
            return task

        tasks = [make_task() for _ in range(chord_size)]

        for index, task in enumerate(tasks):
            backend.on_chord_part_return(task, states.SUCCESS, index)
            self.assertTrue(backend.client.rpush.call_count)
            backend.client.rpush.reset_mock()
        self.assertTrue(backend.client.lrange.call_count)
        join_key = backend.get_key_for_group('group_id', '.j')
        backend.client.delete.assert_called_with(join_key)
        backend.client.expire.assert_called_with(join_key, 86400)

    def test_process_cleanup(self):
        """Smoke test: process_cleanup runs without raising."""
        backend = self.Backend(app=self.app, new_join=True)
        backend.process_cleanup()

    def test_get_set_forget(self):
        """A stored result is readable, and reads as PENDING once forgotten."""
        backend = self.Backend(app=self.app, new_join=True)
        task_id = uuid()
        backend.store_result(task_id, 42, states.SUCCESS)
        self.assertEqual(backend.get_status(task_id), states.SUCCESS)
        self.assertEqual(backend.get_result(task_id), 42)
        backend.forget(task_id)
        self.assertEqual(backend.get_status(task_id), states.PENDING)

    def test_set_expires(self):
        """store_result applies the backend's ``expires`` to the task key."""
        ttl = 512
        backend = self.Backend(expires=ttl, app=self.app, new_join=True)
        task_id = uuid()
        task_key = backend.get_key_for_task(task_id)
        backend.store_result(task_id, 42, states.SUCCESS)
        backend.client.expire.assert_called_with(task_key, ttl)