from __future__ import absolute_import, unicode_literals
from contextlib import contextmanager
from amqp import ChannelError
from kombu import Connection, Producer, Queue, Exchange
from kombu.transport.virtual import QoS
from celery.contrib.migrate import (
StopFiltering,
State,
migrate_task,
migrate_tasks,
filter_callback,
_maybe_queue,
filter_status,
move_by_taskmap,
move_by_idmap,
move_task_by_id,
start_filter,
task_id_in,
task_id_eq,
expand_dest,
move,
)
from celery.utils.encoding import bytes_t, ensure_bytes
from celery.tests.case import AppCase, Mock, override_stdouts, patch
# HACK: switch off ``restore_at_shutdown`` on the QoS class itself (done at
# module level, so it takes effect as soon as this test module is imported)
# in order to ignore an error that otherwise shows up at shutdown.
QoS.restore_at_shutdown = False
def Message(body, exchange='exchange', routing_key='rkey',
            compression=None, content_type='application/json',
            content_encoding='utf-8'):
    """Return a mock message built from the given body and metadata.

    The mock is created with an ``attrs`` mapping holding the body, the
    delivery info (exchange and routing key), the headers (compression),
    the content type, the content encoding and an empty ``properties``
    dict.
    """
    delivery_info = {
        'exchange': exchange,
        'routing_key': routing_key,
    }
    headers = {
        'compression': compression,
    }
    message_attrs = {
        'body': body,
        'delivery_info': delivery_info,
        'headers': headers,
        'content_type': content_type,
        'content_encoding': content_encoding,
        'properties': {},
    }
    return Mock(attrs=message_attrs)
class test_State(AppCase):
    """Checks the ``strtotal`` and ``repr`` output of ``State``."""

    def test_strtotal(self):
        state = State()
        # A fresh State reports '?' as its total.
        self.assertEqual(state.strtotal, '?')

        # Once total_apx is set, the total is its string form.
        state.total_apx = 100
        self.assertEqual(state.strtotal, '100')

    def test_repr(self):
        state = State()
        initial_repr = repr(state)
        self.assertTrue(initial_repr)

        # repr must still be non-empty after ``filtered`` is assigned.
        state.filtered = 'foo'
        filtered_repr = repr(state)
        self.assertTrue(filtered_repr)
class test_move(AppCase):
    """Tests ``move`` through the callback it passes to ``start_filter``."""

    @contextmanager
    def move_context(self, **kwargs):
        # Patch start_filter and republish, call move() with a mock
        # predicate, and yield the callback move() handed to start_filter
        # (its third positional argument) together with the predicate and
        # the republish mock.
        with patch('celery.contrib.migrate.start_filter') as start, \
                patch('celery.contrib.migrate.republish') as republish:
            predicate = Mock(name='predicate')
            move(predicate, app=self.app,
                 connection=self.app.connection(), **kwargs)
            self.assertTrue(start.called)
            positional_args = start.call_args[0]
            yield positional_args[2], predicate, republish

    def msgpair(self, **kwargs):
        # Default task body, with any keyword overrides applied on top.
        body = {'task': 'add', 'id': 'id'}
        body.update(kwargs)
        return body, Message(body)

    def test_move(self):
        with self.move_context() as context:
            callback, pred, republish = context
            body, message = self.msgpair()

            # Predicate returns None: no ack and no republish.
            pred.return_value = None
            callback(body, message)
            self.assertFalse(message.ack.called)
            self.assertFalse(republish.called)

            # Predicate returns a destination: acked and republished.
            pred.return_value = 'foo'
            callback(body, message)
            message.ack.assert_called_with()
            self.assertTrue(republish.called)

    def test_move_transform(self):
        transform = Mock(name='transform')
        transform.return_value = Queue('bar')
        with self.move_context(transform=transform) as context:
            callback, pred, republish = context
            pred.return_value = 'foo'
            body, message = self.msgpair()
            with patch('celery.contrib.migrate.maybe_declare') as maybed:
                callback(body, message)
                # The transform receives the predicate's return value.
                transform.assert_called_with('foo')
                self.assertTrue(maybed.called)
                self.assertTrue(republish.called)

    def test_limit(self):
        with self.move_context(limit=1) as context:
            callback, pred, republish = context
            pred.return_value = 'foo'
            body, message = self.msgpair()
            # With limit=1 the callback raises StopFiltering ...
            with self.assertRaises(StopFiltering):
                callback(body, message)
            # ... and republish has been called by then.
            self.assertTrue(republish.called)

    def test_callback(self):
        user_callback = Mock()
        with self.move_context(callback=user_callback) as context:
            callback, pred, republish = context
            pred.return_value = 'foo'
            body, message = self.msgpair()
            callback(body, message)
            self.assertTrue(republish.called)
            self.assertTrue(user_callback.called)
class test_start_filter(AppCase):
    """Runs ``start_filter`` against a mocked app and task consumer."""

    def test_start(self):
        with patch('celery.contrib.migrate.eventloop') as evloop:
            # The patched event loop raises StopFiltering when called.
            evloop.side_effect = StopFiltering()

            conn = Connection('memory://')
            filt = Mock(name='filter')
            app = Mock()
            app.amqp.queues = {'foo': Queue('foo'), 'bar': Queue('bar')}

            # Mock consumer that records every callback registered on it.
            consumer = Mock(name='consum')
            app.amqp.TaskConsumer.return_value = consumer
            consumer.queues = list(app.amqp.queues.values())
            consumer.channel = conn.default_channel
            consumer.__enter__ = Mock(name='consumer.__enter__')
            consumer.__exit__ = Mock(name='consumer.__exit__')
            consumer.callbacks = []

            def remember_callback(x):
                consumer.callbacks.append(x)
            consumer.register_callback = remember_callback

            body = {'task': 'add', 'id': 'id'}

            def dispatch():
                # Feed a fresh message to each registered callback.
                for registered in consumer.callbacks:
                    registered(body, Message(body))

            start_filter(app, conn, filt,
                         queues='foo,bar', ack_messages=True)
            dispatch()

            del consumer.callbacks[:]
            cb = Mock(name='callback=')
            start_filter(app, conn, filt, tasks='add,mul', callback=cb)
            dispatch()
            self.assertTrue(cb.called)

            on_declare_queue = Mock()
            start_filter(app, conn, filt, tasks='add,mul', queues='foo',
                         on_declare_queue=on_declare_queue)
            self.assertTrue(on_declare_queue.called)

            start_filter(app, conn, filt, queues=['foo', 'bar'])

            del consumer.callbacks[:]
            state = State()
            start_filter(app, conn, filt,
                         tasks='add,mul', callback=cb, state=state, limit=1)
            # With limit=1 at least one callback must raise StopFiltering.
            stop_filtering_raised = False
            for registered in consumer.callbacks:
                try:
                    registered(body, Message(body))
                except StopFiltering:
                    stop_filtering_raised = True
            self.assertTrue(state.count)
            self.assertTrue(stop_filtering_raised)
class test_filter_callback(AppCase):
    """Checks that ``filter_callback`` only forwards listed task names."""

    def test_filter(self):
        target = Mock()
        message = Mock()
        listed_task = {'task': 'add'}
        unlisted_task = {'task': 'div'}
        filt = filter_callback(target, ['add', 'mul'])

        # 'div' is not in the task list, so the callback is not invoked.
        filt(unlisted_task, message)
        self.assertFalse(target.called)

        # 'add' is in the list, so body and message are passed through.
        filt(listed_task, message)
        target.assert_called_with(listed_task, message)
class test_utils(AppCase):
    """Tests for the small helper functions of ``celery.contrib.migrate``."""

    def test_task_id_in(self):
        matching = task_id_in(['A'], {'id': 'A'}, Mock())
        self.assertTrue(matching)
        other = task_id_in(['A'], {'id': 'B'}, Mock())
        self.assertFalse(other)

    def test_task_id_eq(self):
        matching = task_id_eq('A', {'id': 'A'}, Mock())
        self.assertTrue(matching)
        other = task_id_eq('A', {'id': 'B'}, Mock())
        self.assertFalse(other)

    def test_expand_dest(self):
        # With no destination the given exchange/routing key come back.
        fallback = expand_dest(None, 'foo', 'bar')
        self.assertEqual(fallback, ('foo', 'bar'))
        # An explicit destination tuple is returned instead.
        explicit = expand_dest(('b', 'x'), 'foo', 'bar')
        self.assertEqual(explicit, ('b', 'x'))

    def test_maybe_queue(self):
        app = Mock()
        app.amqp.queues = {'foo': 313}
        # A name is looked up in app.amqp.queues.
        by_name = _maybe_queue(app, 'foo')
        self.assertEqual(by_name, 313)
        # A Queue argument compares equal to an identical Queue.
        by_queue = _maybe_queue(app, Queue('foo'))
        self.assertEqual(by_queue, Queue('foo'))

    def test_filter_status(self):
        with override_stdouts() as streams:
            stdout, stderr = streams
            body = {'id': '1', 'task': 'add'}
            filter_status(State(), body, Mock())
            # Something must have been written to stdout.
            self.assertTrue(stdout.getvalue())

    def test_move_by_taskmap(self):
        with patch('celery.contrib.migrate.move') as move_mock:
            move_by_taskmap({'add': Queue('foo')})
            self.assertTrue(move_mock.called)
            # First positional argument given to move() is the predicate.
            predicate = move_mock.call_args[0][0]
            self.assertTrue(predicate({'task': 'add'}, Mock()))

    def test_move_by_idmap(self):
        with patch('celery.contrib.migrate.move') as move_mock:
            move_by_idmap({'123f': Queue('foo')})
            self.assertTrue(move_mock.called)
            predicate = move_mock.call_args[0][0]
            self.assertTrue(predicate({'id': '123f'}, Mock()))

    def test_move_task_by_id(self):
        with patch('celery.contrib.migrate.move') as move_mock:
            move_task_by_id('123f', Queue('foo'))
            self.assertTrue(move_mock.called)
            predicate = move_mock.call_args[0][0]
            destination = predicate({'id': '123f'}, Mock())
            self.assertEqual(destination, Queue('foo'))
class test_migrate_task(AppCase):
    """Checks the arguments ``migrate_task`` passes to ``publish``."""

    def test_removes_compression_header(self):
        message = Message('foo', compression='zlib')
        producer = Mock()
        migrate_task(producer, message.body, message)
        self.assertTrue(producer.publish.called)

        publish_args, publish_kwargs = producer.publish.call_args
        self.assertIsInstance(publish_args[0], bytes_t)
        # Compression is passed as a keyword argument, not as a header.
        self.assertNotIn('compression', publish_kwargs['headers'])

        expected_kwargs = (
            ('compression', 'zlib'),
            ('content_type', 'application/json'),
            ('content_encoding', 'utf-8'),
            ('exchange', 'exchange'),
            ('routing_key', 'rkey'),
        )
        for key, expected in expected_kwargs:
            self.assertEqual(publish_kwargs[key], expected)
class test_migrate_tasks(AppCase):
    """Runs ``migrate_tasks`` between two in-memory connections."""

    def test_migrate(self, name='testcelery'):
        source = Connection('memory://foo')
        dest = Connection('memory://foo')
        # use separate state
        source.default_channel.queues = {}
        dest.default_channel.queues = {}

        exchange = Exchange(name, 'direct')
        queue = Queue(name, exchange=exchange, routing_key=name)
        queue(source.default_channel).declare()
        payloads = ('foo', 'bar', 'baz')
        for payload in payloads:
            Producer(source).publish(
                payload, exchange=name, routing_key=name)
        self.assertTrue(source.default_channel.queues)
        self.assertFalse(dest.default_channel.queues)

        migrate_tasks(source, dest, accept=['text/plain'], app=self.app)

        # The messages are read from the destination in publish order.
        dest_queue = queue(dest.default_channel)
        for payload in payloads:
            self.assertEqual(dest_queue.get().body, ensure_bytes(payload))

        # A user-supplied callback is invoked.
        Producer(source).publish('foo', exchange=name, routing_key=name)
        callback = Mock()
        migrate_tasks(source, dest,
                      callback=callback, accept=['text/plain'],
                      app=self.app)
        self.assertTrue(callback.called)

        # A user-supplied migrate function is invoked.
        migrate = Mock()
        Producer(source).publish('baz', exchange=name, routing_key=name)
        migrate_tasks(source, dest, callback=callback,
                      migrate=migrate, accept=['text/plain'],
                      app=self.app)
        self.assertTrue(migrate.called)

        with patch('kombu.transport.virtual.Channel.queue_declare') as qd:

            def declare_effect(*args, **kwargs):
                # Passive declares fail; any other declare returns a tuple.
                if kwargs.get('passive'):
                    raise ChannelError('some channel error')
                return 0, 3, 0
            qd.side_effect = declare_effect
            migrate_tasks(source, dest, app=self.app)

        # With fresh, empty queue state the callback must not be called.
        source = Connection('memory://')
        source.default_channel.queues = {}
        dest.default_channel.queues = {}
        callback = Mock()
        migrate_tasks(source, dest,
                      callback=callback, accept=['text/plain'],
                      app=self.app)
        self.assertFalse(callback.called)