from __future__ import absolute_import
import os
import signal
import sys
from contextlib import contextmanager
from time import time
from nose import SkipTest
from billiard.common import (
_shutdown_cleanup,
reset_signals,
restart_state,
)
from .utils import Case
try:
from unittest.mock import Mock, call, patch
except ImportError:
from mock import Mock, call, patch # noqa
def signo(name):
    """Return the value of the :mod:`signal` attribute called *name*.

    For example ``signo('SIGTERM')`` is ``signal.SIGTERM``.  An unknown
    name raises :class:`AttributeError`, exactly like ``getattr``.
    """
    number = getattr(signal, name)
    return number
@contextmanager
def termsigs(default, full):
    """Temporarily replace billiard's terminating-signal lists.

    While the ``with`` block runs, ``billiard.common.TERMSIGS_DEFAULT`` is
    *default* and ``billiard.common.TERMSIGS_FULL`` is *full*.  The previous
    values are put back on exit, including when the block raises.
    """
    from billiard import common
    saved = (common.TERMSIGS_DEFAULT, common.TERMSIGS_FULL)
    common.TERMSIGS_DEFAULT = default
    common.TERMSIGS_FULL = full
    try:
        yield
    finally:
        # Always restore, so one test's override never leaks into the next.
        common.TERMSIGS_DEFAULT, common.TERMSIGS_FULL = saved
class test_reset_signals(Case):
    """Tests for ``reset_signals`` and ``_shutdown_cleanup``.

    The ``reset_signals`` tests go through :meth:`assert_context`, which
    patches ``signal.getsignal`` / ``signal.signal`` and overrides billiard's
    terminating-signal lists via ``termsigs``.  No real handlers are installed.
    """

    def setUp(self):
        # os.WTERMSIG and signals such as SIGUSR1 are Unix-only.
        if sys.platform == 'win32':
            self.skipTest('win32: skip')

    def test_shutdown_handler(self):
        # _shutdown_cleanup must call sys.exit with a status from which
        # os.WTERMSIG recovers the original signal number.
        with patch('sys.exit') as sys_exit:
            _shutdown_cleanup(signal.SIGTERM, Mock())
            self.assertTrue(sys_exit.called)
            self.assertEqual(
                os.WTERMSIG(sys_exit.call_args[0][0]), signal.SIGTERM,
            )

    # Tuples (not lists) as defaults: immutable, so they can never be
    # accidentally shared and mutated between calls.

    def test_does_not_reset_ignored_signal(self, sigs=('SIGTERM',)):
        # A signal whose current handler is SIG_IGN must be left alone.
        with self.assert_context(sigs, (), signal.SIG_IGN) as (_, SET):
            self.assertFalse(SET.called)

    def test_does_not_reset_if_current_is_None(self, sigs=('SIGTERM',)):
        # When getsignal reports None, no handler is installed.
        with self.assert_context(sigs, (), None) as (_, SET):
            self.assertFalse(SET.called)

    def test_resets_for_SIG_DFL(self, sigs=('SIGTERM', 'SIGINT', 'SIGUSR1')):
        # Signals at SIG_DFL get _shutdown_cleanup installed, in order.
        with self.assert_context(sigs, (), signal.SIG_DFL) as (_, SET):
            SET.assert_has_calls([
                call(signo(sig), _shutdown_cleanup) for sig in sigs
            ])

    def test_resets_for_obj(self, sigs=('SIGTERM', 'SIGINT', 'SIGUSR1')):
        # Any other existing handler object is also replaced.
        with self.assert_context(sigs, (), object()) as (_, SET):
            SET.assert_has_calls([
                call(signo(sig), _shutdown_cleanup) for sig in sigs
            ])

    def test_handles_errors(self, sigs=('SIGTERM',)):
        # signal.signal is made to raise each of these.  reset_signals()
        # runs inside assert_context, so if it let the error escape, this
        # test would fail with that exception.
        for exc in (OSError(), AttributeError(),
                    ValueError(), RuntimeError()):
            with self.assert_context(sigs, (), signal.SIG_DFL, exc) as (_, S):
                self.assertTrue(S.called)

    @contextmanager
    def assert_context(self, default, full, get_returns=None, set_effect=None):
        """Run ``reset_signals()`` under patched signal APIs, then yield mocks.

        :param default: value installed as ``TERMSIGS_DEFAULT``.
        :param full: value installed as ``TERMSIGS_FULL``.
        :param get_returns: return value of the patched ``signal.getsignal``,
            i.e. the "current handler" reported for every signal.
        :param set_effect: ``side_effect`` of the patched ``signal.signal``,
            e.g. an exception instance to make it raise.

        Before yielding, this checks that ``getsignal`` was queried for every
        name in *default*.  It then yields ``(getsignal_mock, signal_mock)``
        while the patches are still active.
        """
        with termsigs(default, full), \
                patch('signal.getsignal') as GET, \
                patch('signal.signal') as SET:
            GET.return_value = get_returns
            SET.side_effect = set_effect
            reset_signals()
            GET.assert_has_calls([
                call(signo(sig)) for sig in default
            ])
            yield GET, SET
class test_restart_state(Case):
    """Tests for the restart-frequency limiter ``restart_state``."""

    def test_raises(self):
        # At most 100 restarts are allowed within a 1-second window.
        state = restart_state(100, 1)
        # One below the limit: this step succeeds ...
        state.R = 99
        state.step()
        # ... and the following step exceeds the limit.
        with self.assertRaises(state.RestartFreqExceeded):
            state.step()

    def test_time_passed_resets_counter(self):
        state = restart_state(100, 10)

        # Already at the limit inside the window: stepping must raise.
        state.R, state.T = 100, time()
        with self.assertRaises(state.RestartFreqExceeded):
            state.step()

        # Same count, but stepping 20 seconds later (past the 10-second
        # window) starts the counter over at 1 instead of raising.
        state.R, state.T = 100, time()
        state.step(time() + 20)
        self.assertEqual(state.R, 1)