from __future__ import unicode_literals
import difflib
import errno
import json
import os
import posixpath
import socket
import sys
import threading
import unittest
import warnings
from collections import Counter
from contextlib import contextmanager
from copy import copy
from functools import wraps
from unittest.util import safe_repr
from django.apps import apps
from django.conf import settings
from django.core import mail
from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.core.files import locks
from django.core.handlers.wsgi import WSGIHandler, get_path_info
from django.core.management import call_command
from django.core.management.color import no_style
from django.core.management.sql import emit_post_migrate_signal
from django.core.servers.basehttp import WSGIRequestHandler, WSGIServer
from django.core.urlresolvers import clear_url_caches, set_urlconf
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.forms.fields import CharField
from django.http import QueryDict
from django.test.client import Client
from django.test.html import HTMLParseError, parse_html
from django.test.signals import setting_changed, template_rendered
from django.test.utils import (
CaptureQueriesContext, ContextList, compare_xml, modify_settings,
override_settings,
)
from django.utils import six
from django.utils.decorators import classproperty
from django.utils.deprecation import (
RemovedInDjango20Warning, RemovedInDjango110Warning,
)
from django.utils.encoding import force_text
from django.utils.six.moves.urllib.parse import (
unquote, urljoin, urlparse, urlsplit, urlunsplit,
)
from django.utils.six.moves.urllib.request import url2pathname
from django.views.static import serve
__all__ = ('TestCase', 'TransactionTestCase',
'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
def to_list(value):
    """
    Normalize ``value`` to a list.

    An existing list is returned unchanged (the very same object), ``None``
    becomes a new empty list, and any other value is wrapped in a
    one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
def assert_and_parse_html(self, html, user_msg, msg):
    """
    Parse ``html`` with ``parse_html()`` and return the parsed result.

    If parsing raises ``HTMLParseError``, the test case ``self`` is failed
    via ``self.fail()``. The failure text is ``msg`` followed by the parser's
    error detail, combined with ``user_msg`` through
    ``self._formatMessage()``.
    """
    try:
        dom = parse_html(html)
    except HTMLParseError as e:
        # The stdlib HTMLParseError (which carried ``.msg``) was removed in
        # Python 3.5; a plain Exception stand-in has no such attribute.
        # Fall back to the exception itself so a parse failure is reported
        # as a test failure rather than as an AttributeError.
        detail = getattr(e, 'msg', e)
        standardMsg = '%s\n%s' % (msg, detail)
        self.fail(self._formatMessage(user_msg, standardMsg))
    return dom
class _AssertNumQueriesContext(CaptureQueriesContext):
    """
    Context manager that, on a clean exit, has ``test_case`` assert that
    ``len(self)`` equals the expected count ``num``.
    """

    def __init__(self, test_case, num, connection):
        self.test_case = test_case
        self.num = num
        super(_AssertNumQueriesContext, self).__init__(connection)

    def __exit__(self, exc_type, exc_value, traceback):
        super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback)
        # Only check the count when the managed block finished without
        # raising; otherwise let the original exception propagate.
        if exc_type is None:
            actual = len(self)
            sql_log = '\n'.join(
                query['sql'] for query in self.captured_queries
            )
            failure_msg = (
                "%d queries executed, %d expected\nCaptured queries were:\n%s"
                % (actual, self.num, sql_log)
            )
            self.test_case.assertEqual(actual, self.num, failure_msg)
class _AssertTemplateUsedContext(object):
def __init__(self, test_case, template_name):
self.test_case = test_case
self.template_name = template_name
self.rendered_templates = []
self.rendered_template_names = []
self.context = ContextList()
def on_template_render(self, sender, signal, template, context, **kwargs):
self.rendered_templates.append(template)
self.rendered_template_names.append(template.name)
self.context.append(copy(context))
def test(self):
return self.template_name in self.rendered_template_names
def message(self):
return '%s was not rendered.' % self.template_name
def __enter__(self):
template_rendered.connect(self.on_template_render)
return self
def __exit__(self, exc_type, exc_value, traceback):
template_rendered.disconnect(self.on_template_render)
if exc_type is not None:
return
if not self.test():
message = self.message()
if len(self.rendered_templates) == 0:
message += ' No template was rendered.'
else:
message += ' Following templates were rendered: %s' % (
', '.join(self.rendered_template_names))
self.test_case.fail(message)
class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
    """
    Inverse of the parent check: passes only when the given template name
    was NOT recorded among the rendered template names.
    """

    def test(self):
        """Return True if the template name was not recorded."""
        was_rendered = self.template_name in self.rendered_template_names
        return not was_rendered

    def message(self):
        """Return the leading part of the failure message."""
        return '%s was rendered.' % self.template_name
class _CursorFailure(object):
def __init__(self, cls_name, wrapped):
self.cls_name = cls_name
self.wrapped = wrapped
def __call__(self):
raise AssertionError(
"Database queries aren't allowed in SimpleTestCase. "
"Either use TestCase or TransactionTestCase to ensure proper test isolation or "
"set %s.allow_database_queries to True to silence this failure." % self.cls_name
)
class SimpleTestCase(unittest.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
client_class = Client
_overridden_settings = None
_modified_settings = None
# Tests shouldn't be allowed to query the database since
# this base class doesn't enforce any isolation.
allow_database_queries = False
@classmethod
def setUpClass(cls):
    """
    Enable the class-level settings overrides/modifications, if any, and,
    unless ``allow_database_queries`` is true, replace ``cursor`` on every
    connection with a ``_CursorFailure`` that wraps the original.
    """
    super(SimpleTestCase, cls).setUpClass()
    overrides = cls._overridden_settings
    if overrides:
        cls._cls_overridden_context = override_settings(**overrides)
        cls._cls_overridden_context.enable()
    modifications = cls._modified_settings
    if modifications:
        cls._cls_modified_context = modify_settings(modifications)
        cls._cls_modified_context.enable()
    if cls.allow_database_queries:
        return
    for alias in connections:
        conn = connections[alias]
        # The original cursor is kept on the wrapper for tearDownClass().
        conn.cursor = _CursorFailure(cls.__name__, conn.cursor)
@classmethod
def tearDownClass(cls):
    """
    Undo setUpClass(): restore each connection's original ``cursor`` (when
    database queries were blocked), then disable and drop the class-level
    settings contexts, modified settings first and overridden settings
    second.
    """
    if not cls.allow_database_queries:
        for alias in connections:
            conn = connections[alias]
            # conn.cursor is the wrapper; .wrapped is the saved original.
            conn.cursor = conn.cursor.wrapped
    for context_attr in ('_cls_modified_context', '_cls_overridden_context'):
        if hasattr(cls, context_attr):
            getattr(cls, context_attr).disable()
            delattr(cls, context_attr)
    super(SimpleTestCase, cls).tearDownClass()
def __call__(self, result=None):
    """
    Run the test, surrounding the default ``__call__`` with
    ``_pre_setup()`` and ``_post_teardown()`` so that test cases don't
    have to call ``super().setUp()`` themselves.

    Both hooks are bypassed when the class or the test method is marked
    with ``__unittest_skip__``. An exception raised by either hook is
    recorded on ``result`` as an error; if ``_pre_setup()`` fails, the
    test itself is not run.

    NOTE(review): ``result`` defaults to None, yet ``result.addError`` is
    called unconditionally when a hook fails. Callers appear to be
    expected to pass a result object; confirm.
    """
    test_method = getattr(self, self._testMethodName)
    is_skipped = any(
        getattr(candidate, "__unittest_skip__", False)
        for candidate in (self.__class__, test_method)
    )
    if is_skipped:
        super(SimpleTestCase, self).__call__(result)
        return

    try:
        self._pre_setup()
    except Exception:
        result.addError(self, sys.exc_info())
        return

    super(SimpleTestCase, self).__call__(result)

    try:
        self._post_teardown()
    except Exception:
        result.addError(self, sys.exc_info())
def _pre_setup(self):
    """
    Prepare the environment for a single test:

    * instantiate ``client_class`` as ``self.client``;
    * if the class has a 'urls' attribute, replace ROOT_URLCONF with it;
    * reset the mail test outbox to an empty list.
    """
    make_client = self.client_class
    self.client = make_client()
    self._urlconf_setup()
    mail.outbox = []
def _urlconf_setup(self):
    """
    Handle the deprecated ``urls`` attribute: emit a deprecation warning,
    save the current ROOT_URLCONF, point it at ``self.urls`` and clear the
    URL caches. Does nothing when the test case has no ``urls`` attribute.
    """
    if not hasattr(self, 'urls'):
        return
    deprecation_msg = (
        "SimpleTestCase.urls is deprecated and will be removed in "
        "Django 1.10. Use @override_settings(ROOT_URLCONF=...) "
        "in %s instead." % self.__class__.__name__
    )
    warnings.warn(deprecation_msg, RemovedInDjango110Warning, stacklevel=2)
    set_urlconf(None)
    # Saved so that _urlconf_teardown() can restore it.
    self._old_root_urlconf = settings.ROOT_URLCONF
    settings.ROOT_URLCONF = self.urls
    clear_url_caches()
def _post_teardown(self):
    """
    Clean up after a single test.

    Currently this only restores the original ROOT_URLCONF if it was
    replaced during setup.
    """
    self._urlconf_teardown()
def _urlconf_teardown(self):
    """
    Restore the ROOT_URLCONF saved in ``_old_root_urlconf`` and clear the
    URL caches. Does nothing when no saved value is present.
    """
    if not hasattr(self, '_old_root_urlconf'):
        return
    set_urlconf(None)
    settings.ROOT_URLCONF = self._old_root_urlconf
    clear_url_caches()
def settings(self, **kwargs):
    """
    Return a context manager that temporarily sets the given settings and
    reverts to the original values when the context exits.
    """
    settings_override = override_settings(**kwargs)
    return settings_override
def modify_settings(self, **kwargs):
    """
    Return a context manager that temporarily applies changes to a list
    setting and reverts to the original value when the context exits.
    """
    # Inside the method body this name resolves to the module-level
    # modify_settings, not to this method, so there is no recursion.
    settings_modifier = modify_settings(**kwargs)
    return settings_modifier
def assertRedirects(self, response, expected_url, status_code=302,
target_status_code=200, host=None, msg_prefix='',
fetch_redirect_response=True):
"""Asserts that a response redirected to a specific URL, and that the
redirect URL can be loaded.
Note that assertRedirects won't work for external links since it uses
TestClient to do a request (use fetch_redirect_response=False to check
such links without fetching them).
"""
if host is not None:
warnings.warn(
"The host argument is deprecated and no longer used by assertRedirects",
RemovedInDjango20Warning, stacklevel=2
)
if msg_prefix:
msg_prefix += ": "
if hasattr(response, 'redirect_chain'):
# The request was a followed redirect
self.assertTrue(len(response.redirect_chain) > 0,
msg_prefix + "Response didn't redirect as expected: Response"
" code was %d (expected %d)" %
(response.status_code, status_code))
self.assertEqual(response.redirect_chain[0][1], status_code,
msg_prefix + "Initial response didn't redirect as expected:"
" Response code was %d (expected %d)" %
(response.redirect_chain[0][1], status_code))
url, status_code = response.redirect_chain[-1]
scheme, netloc, path, query, fragment = urlsplit(url)
self.assertEqual(response.status_code, target_status_code,
msg_prefix + "Response didn't redirect as expected: Final"
" Response code was %d (expected %d)" %
(response.status_code, target_status_code))
else:
# Not a followed redirect
self.assertEqual(response.status_code, status_code,
msg_prefix + "Response didn't redirect as expected: Response"
" code was %d (expected %d)" %
(response.status_code, status_code))
url = response.url
scheme, netloc, path, query, fragment = urlsplit(url)
# Prepend the request path to handle relative path redirects.
if not path.startswith('/'):
url = urljoin(response.request['PATH_INFO'], url)
path = urljoin(response.request['PATH_INFO'], path)
if fetch_redirect_response:
redirect_response = response.client.get(path, QueryDict(query),
secure=(scheme == 'https'))
# Get the redirection page, using the same client that was used
# to obtain the original response.
self.assertEqual(redirect_response.status_code, target_status_code,
msg_prefix + "Couldn't retrieve redirection page '%s':"
" response code was %d (expected %d)" %
(path, redirect_response.status_code, target_status_code))
if url != expected_url:
# For temporary backwards compatibility, try to compare with a relative url
e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
relative_url = urlunsplit(('', '', e_path, e_query, e_fragment))
if url == relative_url:
warnings.warn(
"assertRedirects had to strip the scheme and domain from the "
"expected URL, as it was always added automatically to URLs "
"before Django 1.9. Please update your expected URLs by "
"removing the scheme and domain.",
Loading ...