from __future__ import absolute_import
import pickle
from decimal import Decimal
from random import shuffle
from time import time
from itertools import count
from celery import states
from celery.events import Event
from celery.events.state import (
State,
Worker,
Task,
HEARTBEAT_EXPIRE_WINDOW,
HEARTBEAT_DRIFT_MAX,
)
from celery.five import range
from celery.utils import uuid
from celery.tests.case import AppCase, Mock, patch
try:
Decimal(2.6)
except TypeError: # pragma: no cover
# Py2.6: Must first convert float to str
_float_to_decimal = str
else:
def _float_to_decimal(f): # noqa
return f
class replay(object):
    """Feed a fixed list of events into ``state``, one per iteration.

    Subclasses assign a list of event dicts to ``self.events`` in
    :meth:`setup`.  Before an event is handed to ``state.event`` it gets
    ``local_received`` copied from its ``timestamp``.  If it has no
    ``clock`` it is also given one, one greater than the previous clock
    seen.
    """

    def __init__(self, state):
        self.state = state
        self.rewind()
        self.setup()
        self.current_clock = 0

    def setup(self):
        """Hook for subclasses: populate ``self.events``."""
        pass

    def next_event(self):
        """Return the next annotated event.

        Raises IndexError once every event has been consumed.
        """
        ev = self.events[next(self.position)]
        ev['local_received'] = ev['timestamp']
        try:
            self.current_clock = ev['clock']
        except KeyError:
            # No explicit logical clock: continue from the last one seen.
            ev['clock'] = self.current_clock = self.current_clock + 1
        return ev

    def __iter__(self):
        return self

    def __next__(self):
        # Only exhaustion of ``self.events`` ends the iteration.  An
        # IndexError raised by ``state.event`` is a genuine failure and must
        # propagate instead of silently stopping the replay.
        try:
            ev = self.next_event()
        except IndexError:
            raise StopIteration()
        self.state.event(ev)
    next = __next__

    def rewind(self):
        """Restart from the first event; returns ``self`` for chaining."""
        self.position = count(0)
        return self

    def play(self):
        """Consume every remaining event."""
        for _ in self:
            pass
class ev_worker_online_offline(replay):
    """Replay worker ``utest1`` going online and then offline."""

    def setup(self):
        host = 'utest1'
        self.events = [
            Event(event_type, hostname=host)
            for event_type in ('worker-online', 'worker-offline')
        ]
class ev_worker_heartbeats(replay):
    """Replay a stale heartbeat followed by a current one for ``utest1``.

    The first heartbeat is dated ``HEARTBEAT_EXPIRE_WINDOW * 2`` seconds in
    the past; the second uses the default (current) timestamp.
    """

    def setup(self):
        stale_timestamp = time() - HEARTBEAT_EXPIRE_WINDOW * 2
        stale = Event('worker-heartbeat', hostname='utest1',
                      timestamp=stale_timestamp)
        fresh = Event('worker-heartbeat', hostname='utest1')
        self.events = [stale, fresh]
class ev_task_states(replay):
    """Walk a single task (``self.tid``) through every task event type.

    Order: received, started, revoked, retried, failed, succeeded, all
    reported by ``utest1``.
    """

    def setup(self):
        tid = self.tid = uuid()
        # Fields shared by every event in the sequence.
        common = dict(uuid=tid, hostname='utest1')
        received = Event('task-received', name='task1',
                         args='(2, 2)', kwargs="{'foo': 'bar'}",
                         retries=0, eta=None, **common)
        started = Event('task-started', **common)
        revoked = Event('task-revoked', **common)
        retried = Event('task-retried', exception="KeyError('bar')",
                        traceback='line 2 at main', **common)
        failed = Event('task-failed', exception="KeyError('foo')",
                       traceback='line 1 at main', **common)
        succeeded = Event('task-succeeded', result='4',
                          runtime=0.1234, **common)
        self.events = [received, started, revoked,
                       retried, failed, succeeded]
def QTEV(type, uuid, hostname, clock, name=None, timestamp=None):
    """Quick task event.

    Build a ``task-<type>`` event carrying an explicit logical ``clock``.
    A falsy ``timestamp`` is replaced by the current time.
    """
    event_type = 'task-{0}'.format(type)
    if not timestamp:
        timestamp = time()
    return Event(event_type, uuid=uuid, hostname=hostname,
                 clock=clock, name=name, timestamp=timestamp)
class ev_logical_clock_ordering(replay):
    """Replay three tasks whose events arrive out of logical-clock order.

    The ``received`` event for tA on w3 (clock ``offset + 13``) comes before
    the ``succeeded`` event for tC (clock ``offset + 12``) in the list.  Any
    ordering by time therefore has to come from the clock values, not from
    arrival order.  ``offset`` shifts every clock so the sequence can be
    replayed repeatedly into the same State.
    """

    def __init__(self, state, offset=0, uids=None):
        self.offset = offset or 0
        self.uids = self.setuids(uids)
        super(ev_logical_clock_ordering, self).__init__(state)

    def setuids(self, uids):
        """Bind ``tA``/``tB``/``tC`` to ``uids`` (or three fresh uuids)."""
        uids = self.tA, self.tB, self.tC = uids or [uuid(), uuid(), uuid()]
        return uids

    def setup(self):
        offset = self.offset
        tA, tB, tC = self.uids
        self.events = [
            QTEV('received', tA, 'w1', name='tA', clock=offset + 1),
            QTEV('received', tB, 'w2', name='tB', clock=offset + 1),
            QTEV('started', tA, 'w1', name='tA', clock=offset + 3),
            QTEV('received', tC, 'w2', name='tC', clock=offset + 3),
            QTEV('started', tB, 'w2', name='tB', clock=offset + 5),
            QTEV('retried', tA, 'w1', name='tA', clock=offset + 7),
            QTEV('succeeded', tB, 'w2', name='tB', clock=offset + 9),
            QTEV('started', tC, 'w2', name='tC', clock=offset + 10),
            # Deliberately listed before the lower-clocked event below.
            QTEV('received', tA, 'w3', name='tA', clock=offset + 13),
            # Fixed typo: was 'succeded', which produced an unknown
            # 'task-succeded' event instead of 'task-succeeded'.
            QTEV('succeeded', tC, 'w2', name='tC', clock=offset + 12),
            QTEV('started', tA, 'w3', name='tA', clock=offset + 14),
            # Fixed name typo: was 'TA', inconsistent with the other tA events.
            QTEV('succeeded', tA, 'w3', name='tA', clock=offset + 16),
        ]

    def rewind_with_offset(self, offset, uids=None):
        """Rebuild the events with a new clock ``offset`` and restart."""
        self.offset = offset
        self.uids = self.setuids(uids or self.uids)
        self.setup()
        self.rewind()
class ev_snapshot(replay):
    """Three workers come online, then 20 tasks are received.

    Even-numbered tasks are ``task2`` on ``utest2``; odd-numbered tasks are
    ``task1`` on ``utest1``.  ``utest3`` never receives a task.
    """

    def setup(self):
        self.events = [
            Event('worker-online', hostname='utest1'),
            Event('worker-online', hostname='utest2'),
            Event('worker-online', hostname='utest3'),
        ]
        for i in range(20):
            even = not i % 2
            # Conditional expressions replace the fragile ``c and a or b``
            # idiom; ``task_name`` avoids shadowing the ``type`` builtin.
            worker = 'utest2' if even else 'utest1'
            task_name = 'task2' if even else 'task1'
            self.events.append(Event('task-received', name=task_name,
                                     uuid=uuid(), hostname=worker))
class test_Worker(AppCase):
    """Tests for :class:`celery.events.state.Worker`."""

    def test_equality(self):
        # Workers with the same hostname are equal and hash alike.
        self.assertEqual(Worker(hostname='foo').hostname, 'foo')
        foo_a = Worker(hostname='foo')
        foo_b = Worker(hostname='foo')
        bar = Worker(hostname='bar')
        self.assertEqual(foo_a, foo_b)
        self.assertNotEqual(foo_a, bar)
        self.assertEqual(hash(foo_a), hash(foo_b))
        self.assertNotEqual(hash(foo_a), hash(bar))

    def test_compatible_with_Decimal(self):
        # A Decimal timestamp/freq must not break heartbeat bookkeeping.
        hostname = 'george@vandelay.com'
        worker = Worker(hostname)
        timestamp = Decimal(_float_to_decimal(time()))
        local_received = time()
        fields = {
            'hostname': hostname,
            'timestamp': timestamp,
            'local_received': local_received,
            'freq': Decimal(_float_to_decimal(5.6335431)),
        }
        worker.event('worker-online', timestamp, local_received,
                     fields=fields)
        self.assertTrue(worker.alive)

    def test_survives_missing_timestamp(self):
        worker = Worker(hostname='foo')
        worker.event('heartbeat')
        self.assertEqual(worker.heartbeats, [])

    def test_repr(self):
        worker = Worker(hostname='foo')
        self.assertTrue(repr(worker))

    def test_drift_warning(self):
        # A timestamp far ahead of local_received should trigger a warning.
        worker = Worker(hostname='foo')
        with patch('celery.events.state.warn') as warn:
            skewed = time() + (HEARTBEAT_DRIFT_MAX * 2)
            worker.event(None, skewed, time())
            self.assertTrue(warn.called)
            message = warn.call_args[0][0]
            self.assertIn('Substantial drift', message)

    def test_updates_heartbeat(self):
        worker = Worker(hostname='foo')
        worker.event(None, time(), time())
        self.assertEqual(len(worker.heartbeats), 1)
        first = worker.heartbeats[0]
        # A heartbeat received 10s earlier is recorded, and the first one
        # is expected to stay last in the list.
        worker.event(None, time(), time() - 10)
        self.assertEqual(len(worker.heartbeats), 2)
        self.assertEqual(worker.heartbeats[-1], first)
class test_Task(AppCase):
    """Tests for :class:`celery.events.state.Task`."""

    def test_equality(self):
        # Tasks with the same uuid are equal and hash alike.
        self.assertEqual(Task(uuid='foo').uuid, 'foo')
        foo_a, foo_b, bar = Task(uuid='foo'), Task(uuid='foo'), Task(uuid='bar')
        self.assertEqual(foo_a, foo_b)
        self.assertNotEqual(foo_a, bar)
        self.assertEqual(hash(foo_a), hash(foo_b))
        self.assertNotEqual(hash(foo_a), hash(bar))

    def test_info(self):
        attrs = dict(
            name='tasks.add',
            args='(2, 2)',
            kwargs='{}',
            retries=2,
            result=42,
            eta=1,
            runtime=0.0001,
            expires=1,
            foo=None,
            exception=1,
            received=time() - 10,
            started=time() - 8,
            exchange='celery',
            routing_key='celery',
            succeeded=time(),
        )
        task = Task(uuid='abcdefg', **attrs)

        # Default call reports exactly the default info fields.
        default_fields = sorted(task._info_fields)
        self.assertEqual(default_fields, sorted(task.info()))

        # ``extra`` adds fields on top of the defaults.
        extra = ('received', )
        with_extra = sorted(task._info_fields + extra)
        self.assertEqual(with_extra, sorted(task.info(extra=extra)))

        # An explicit field list restricts the output.
        wanted = ['args', 'kwargs']
        self.assertEqual(sorted(wanted), sorted(task.info(wanted).keys()))

        self.assertFalse(list(task.info('foo')))

    def test_ready(self):
        task = Task(uuid='abcdefg', name='tasks.add')
        task.event('received', time(), time())
        self.assertFalse(task.ready)
        task.event('succeeded', time(), time())
        self.assertTrue(task.ready)

    def test_sent(self):
        # A 'sent' event maps to the PENDING state.
        task = Task(uuid='abcdefg', name='tasks.add')
        task.event('sent', time(), time())
        self.assertEqual(task.state, states.PENDING)

    def test_merge(self):
        # Events applied out of order: the late 'received' must not override
        # FAILURE, but its fields are still merged into the task.
        task = Task()
        for event_type in ('failed', 'started'):
            task.event(event_type, time(), time())
        received_fields = {'name': 'tasks.add', 'args': (2, 2)}
        task.event('received', time(), time(), received_fields)
        self.assertEqual(task.state, states.FAILURE)
        self.assertEqual(task.name, 'tasks.add')
        self.assertTupleEqual(task.args, (2, 2))
        task.event('retried', time(), time())
        self.assertEqual(task.state, states.RETRY)

    def test_repr(self):
        task = Task(uuid='xxx', name='tasks.add')
        self.assertTrue(repr(task))
class test_State(AppCase):
    """Tests for :class:`celery.events.state.State`."""

    def test_repr(self):
        self.assertTrue(repr(State()))

    def test_pickleable(self):
        self.assertTrue(pickle.loads(pickle.dumps(State())))

    def _assert_time_order(self, state, *expected):
        """Assert ``state.tasks_by_time()`` starts with the uuids in ``expected``."""
        now = list(state.tasks_by_time())
        for position, uid in enumerate(expected):
            self.assertEqual(now[position][0], uid)

    def test_task_logical_clock_ordering(self):
        state = State()
        r = ev_logical_clock_ordering(state)
        tA, tB, tC = r.uids
        r.play()
        # Expected order matches the descending final clock of each task:
        # tA (offset+16), tC (offset+12), tB (offset+9).
        self._assert_time_order(state, tA, tC, tB)
        for _ in range(1000):
            # Permute which uuid plays which role and replay with clocks
            # beyond everything already recorded.
            shuffle(r.uids)
            tA, tB, tC = r.uids
            r.rewind_with_offset(r.current_clock + 1, r.uids)
            r.play()
            self._assert_time_order(state, tA, tC, tB)

    def test_worker_online_offline(self):
        r = ev_worker_online_offline(State())
        next(r)
        self.assertTrue(r.state.alive_workers())
        self.assertTrue(r.state.workers['utest1'].alive)
        r.play()
        self.assertFalse(r.state.alive_workers())
        self.assertFalse(r.state.workers['utest1'].alive)

    def test_itertasks(self):
        s = State()
        s.tasks = {'a': 'a', 'b': 'b', 'c': 'c', 'd': 'd'}
        self.assertEqual(len(list(s.itertasks(limit=2))), 2)

    def test_worker_heartbeat_expire(self):
        # The first heartbeat is stale, so the worker starts out dead; the
        # second (current) heartbeat brings it back to life.
        r = ev_worker_heartbeats(State())
        next(r)
        self.assertFalse(r.state.alive_workers())
        self.assertFalse(r.state.workers['utest1'].alive)
        r.play()
        self.assertTrue(r.state.alive_workers())
        self.assertTrue(r.state.workers['utest1'].alive)

    def test_task_states(self):
        r = ev_task_states(State())

        # RECEIVED
        next(r)
        # assertIn reports the missing key and container on failure.
        self.assertIn(r.tid, r.state.tasks)
        task = r.state.tasks[r.tid]
        self.assertEqual(task.state, states.RECEIVED)
        self.assertTrue(task.received)
        self.assertEqual(task.timestamp, task.received)
        self.assertEqual(task.worker.hostname, 'utest1')

        # STARTED
        next(r)
        self.assertTrue(r.state.workers['utest1'].alive,
                        'any task event adds worker heartbeat')
        self.assertEqual(task.state, states.STARTED)
        self.assertTrue(task.started)
        self.assertEqual(task.timestamp, task.started)
        self.assertEqual(task.worker.hostname, 'utest1')

        # REVOKED
        next(r)
        self.assertEqual(task.state, states.REVOKED)
        self.assertTrue(task.revoked)
        self.assertEqual(task.timestamp, task.revoked)
        self.assertEqual(task.worker.hostname, 'utest1')

        # RETRY
        next(r)
        self.assertEqual(task.state, states.RETRY)
        self.assertTrue(task.retried)
        self.assertEqual(task.timestamp, task.retried)
        self.assertEqual(task.worker.hostname, 'utest1')
        self.assertEqual(task.exception, "KeyError('bar')")
        self.assertEqual(task.traceback, 'line 2 at main')

        # FAILURE
        next(r)
        self.assertEqual(task.state, states.FAILURE)
        self.assertTrue(task.failed)
        self.assertEqual(task.timestamp, task.failed)
        self.assertEqual(task.worker.hostname, 'utest1')
        self.assertEqual(task.exception, "KeyError('foo')")
        self.assertEqual(task.traceback, 'line 1 at main')

        # SUCCESS
        next(r)
        self.assertEqual(task.state, states.SUCCESS)
        self.assertTrue(task.succeeded)
        self.assertEqual(task.timestamp, task.succeeded)
        self.assertEqual(task.worker.hostname, 'utest1')
        self.assertEqual(task.result, '4')
        self.assertEqual(task.runtime, 0.1234)

    def assertStateEmpty(self, state):
        """Assert ``state`` holds no tasks, workers or counts."""
        self.assertFalse(state.tasks)
        self.assertFalse(state.workers)
        self.assertFalse(state.event_count)
        self.assertFalse(state.task_count)

    def assertState(self, state):
        """Assert ``state`` holds tasks, workers and non-zero counts."""
        self.assertTrue(state.tasks)
        self.assertTrue(state.workers)
        self.assertTrue(state.event_count)
        self.assertTrue(state.task_count)

    def test_freeze_while(self):
        s = State()
        r = ev_snapshot(s)
        r.play()

        def work():
            pass

        s.freeze_while(work, clear_after=True)
        self.assertFalse(s.event_count)

        s2 = State()
        r = ev_snapshot(s2)
        r.play()
        s2.freeze_while(work, clear_after=False)
        self.assertTrue(s2.event_count)

    def test_clear_tasks(self):
        s = State()
        r = ev_snapshot(s)
        r.play()
        self.assertTrue(s.tasks)
        s.clear_tasks(ready=False)
        self.assertFalse(s.tasks)

    def test_clear(self):
        r = ev_snapshot(State())
        r.play()
        self.assertTrue(r.state.event_count)
        self.assertTrue(r.state.workers)
        self.assertTrue(r.state.tasks)
        self.assertTrue(r.state.task_count)
        # clear() with no arguments resets counters and workers but leaves
        # these (only received, not finished) tasks in place...
        r.state.clear()
        self.assertFalse(r.state.event_count)
        self.assertFalse(r.state.workers)
        self.assertTrue(r.state.tasks)
        self.assertFalse(r.state.task_count)
        # ...while clear(False) removes them as well.
        r.state.clear(False)
        self.assertFalse(r.state.tasks)

    def test_task_types(self):
        r = ev_snapshot(State())
        r.play()
        self.assertEqual(sorted(r.state.task_types()), ['task1', 'task2'])

    def test_tasks_by_timestamp(self):
        r = ev_snapshot(State())
        r.play()
        self.assertEqual(len(list(r.state.tasks_by_timestamp())), 20)

    def test_tasks_by_type(self):
        r = ev_snapshot(State())
        r.play()
        self.assertEqual(len(list(r.state.tasks_by_type('task1'))), 10)
        self.assertEqual(len(list(r.state.tasks_by_type('task2'))), 10)

    def test_alive_workers(self):
        r = ev_snapshot(State())
        r.play()
        self.assertEqual(len(r.state.alive_workers()), 3)

    def test_tasks_by_worker(self):
        r = ev_snapshot(State())
        r.play()
        self.assertEqual(len(list(r.state.tasks_by_worker('utest1'))), 10)
        self.assertEqual(len(list(r.state.tasks_by_worker('utest2'))), 10)

    def test_survives_unknown_worker_event(self):
        s = State()
        s.event({
            'type': 'worker-unknown-event-xxx',
            'foo': 'bar',
        })
        s.event({
            'type': 'worker-unknown-event-xxx',
            'hostname': 'xxx',
            'foo': 'bar',
        })

    def test_survives_unknown_worker_leaving(self):
        s = State(on_node_leave=Mock(name='on_node_leave'))
        (worker, created), subject = s.event({
            'type': 'worker-offline',
            'hostname': 'unknown@vandelay.com',
            'timestamp': time(),
            'local_received': time(),
            'clock': 301030134894833,
        })
        self.assertEqual(worker, Worker('unknown@vandelay.com'))
        self.assertFalse(created)
        self.assertEqual(subject, 'offline')
        self.assertNotIn('unknown@vandelay.com', s.workers)
        s.on_node_leave.assert_called_with(worker)

    def test_on_node_join_callback(self):
        s = State(on_node_join=Mock(name='on_node_join'))
        (worker, created), subject = s.event({
            'type': 'worker-online',
            'hostname': 'george@vandelay.com',
            'timestamp': time(),
            'local_received': time(),
            'clock': 34314,
        })
        self.assertTrue(worker)
        self.assertTrue(created)
        self.assertEqual(subject, 'online')
        self.assertIn('george@vandelay.com', s.workers)
        s.on_node_join.assert_called_with(worker)

    def test_survives_unknown_task_event(self):
        s = State()
        s.event(
            {
                'type': 'task-unknown-event-xxx',
                'foo': 'bar',
                'uuid': 'x',
                'hostname': 'y',
                'timestamp': time(),
                'local_received': time(),
                'clock': 0,
            },
        )

    def test_limits_maxtasks(self):
        # With max_tasks_in_memory=1 and heap_multiplier=2 the task heap is
        # expected to retain only the two newest entries (clocks 4 and 5).
        s = State(max_tasks_in_memory=1)
        s.heap_multiplier = 2
        s.event({
            'type': 'task-unknown-event-xxx',
            'foo': 'bar',
            'uuid': 'x',
            'hostname': 'y',
            'clock': 3,
            'timestamp': time(),
            'local_received': time(),
        })
        s.event({
            'type': 'task-unknown-event-xxx',
            'foo': 'bar',
            'uuid': 'y',
            'hostname': 'y',
            'clock': 4,
            'timestamp': time(),
            'local_received': time(),
        })
        s.event({
            'type': 'task-unknown-event-xxx',
            'foo': 'bar',
            'uuid': 'z',
            'hostname': 'y',
            'clock': 5,
            'timestamp': time(),
            'local_received': time(),
        })
        self.assertEqual(len(s._taskheap), 2)
        self.assertEqual(s._taskheap[0].clock, 4)
        self.assertEqual(s._taskheap[1].clock, 5)
        # A duplicated heap entry must not break tasks_by_time().
        s._taskheap.append(s._taskheap[0])
        self.assertTrue(list(s.tasks_by_time()))

    def test_callback(self):
        scratch = {}

        def callback(state, event):
            scratch['recv'] = True

        s = State(callback=callback)
        s.event({'type': 'worker-online'})
        self.assertTrue(scratch.get('recv'))