from __future__ import absolute_import
import os
import base64
from kombu.serialization import registry
from celery.exceptions import SecurityError
from celery.security.serialization import SecureSerializer, register_auth
from celery.security.certificate import Certificate, CertStore
from celery.security.key import PrivateKey
from . import CERT1, CERT2, KEY1, KEY2
from .case import SecurityCase
class test_SecureSerializer(SecurityCase):
    """Tests for :class:`SecureSerializer` signing and signature checks."""

    def _get_s(self, key, cert, certs):
        """Build a :class:`SecureSerializer` for a test.

        :param key: private key data, wrapped in :class:`PrivateKey` and
            used to sign outgoing messages.
        :param cert: certificate data, wrapped in :class:`Certificate` and
            used as this serializer's own certificate.
        :param certs: iterable of certificate data; each one is added to a
            fresh :class:`CertStore` of certificates trusted on deserialize.
        :returns: the configured :class:`SecureSerializer`.
        """
        store = CertStore()
        for c in certs:
            store.add_cert(Certificate(c))
        return SecureSerializer(PrivateKey(key), Certificate(cert), store)

    def test_serialize(self):
        """A message signed and verified by the same serializer round-trips."""
        s = self._get_s(KEY1, CERT1, [CERT1])
        self.assertEqual(s.deserialize(s.serialize('foo')), 'foo')

    def test_deserialize(self):
        """Deserializing data that is not a signed message raises."""
        s = self._get_s(KEY1, CERT1, [CERT1])
        self.assertRaises(SecurityError, s.deserialize, 'bad data')

    def test_unmatched_key_cert(self):
        """Signing with KEY1 while presenting CERT2 fails verification."""
        s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
        self.assertRaises(SecurityError,
                          s.deserialize, s.serialize('foo'))

    def test_unknown_source(self):
        """Messages from a certificate missing from the store are rejected."""
        # s1 trusts only CERT2 but signs as CERT1; s2 trusts nothing.
        s1 = self._get_s(KEY1, CERT1, [CERT2])
        s2 = self._get_s(KEY1, CERT1, [])
        self.assertRaises(SecurityError,
                          s1.deserialize, s1.serialize('foo'))
        self.assertRaises(SecurityError,
                          s2.deserialize, s2.serialize('foo'))

    def test_self_send(self):
        """Two identically configured serializers accept each other's data."""
        s1 = self._get_s(KEY1, CERT1, [CERT1])
        s2 = self._get_s(KEY1, CERT1, [CERT1])
        self.assertEqual(s2.deserialize(s1.serialize('foo')), 'foo')

    def test_separate_ends(self):
        """A receiver with its own key verifies the sender via the store."""
        s1 = self._get_s(KEY1, CERT1, [CERT2])
        s2 = self._get_s(KEY2, CERT2, [CERT1])
        self.assertEqual(s2.deserialize(s1.serialize('foo')), 'foo')

    def test_register_auth(self):
        """``register_auth`` installs a decoder for 'application/data'."""
        # NOTE(review): this registers into kombu's process-wide registry
        # and is not undone afterwards; other tests may observe it.
        register_auth(KEY1, CERT1, '')
        self.assertIn('application/data', registry._decoders)

    def test_lots_of_sign(self):
        """Round-trip many random payloads through one serializer."""
        # The serializer does not depend on the payload, so build it (and
        # load the key/certificates) once instead of on every iteration.
        s = self._get_s(KEY1, CERT1, [CERT1])
        for _ in range(1000):
            # urlsafe_b64encode returns bytes on Python 3.  Decode to text
            # (the output is pure ASCII) so the payload can go through the
            # underlying serializer (presumably JSON by default, which only
            # handles text; confirm) and compares equal to the decoded
            # result on both Python 2 and Python 3.
            rdata = base64.urlsafe_b64encode(os.urandom(265)).decode('ascii')
            self.assertEqual(s.deserialize(s.serialize(rdata)), rdata)