Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

aroundthecode / lxml   python

Repository URL to install this package:

/ doctestcompare.py

"""
lxml-based doctest output comparison.

Note: normally, you should just import the `lxml.usedoctest` and
`lxml.html.usedoctest` modules from within a doctest, instead of this
one::

    >>> import lxml.usedoctest # for XML output

    >>> import lxml.html.usedoctest # for HTML output

To use this module directly, you must call ``lxmldoctest.install()``,
which will cause doctest to use this in all subsequent calls.

This changes the way output is checked and comparisons are made for
XML or HTML-like content.

XML or HTML content is noticed because the example starts with ``<``
(it's HTML if it starts with ``<html``).  You can also use the
``PARSE_HTML`` and ``PARSE_XML`` flags to force parsing.

Some rough wildcard-like things are allowed.  Whitespace is generally
ignored (except in attributes).  In text (attributes and text in the
body) you can use ``...`` as a wildcard.  In an example it also
matches any trailing tags in the element, though it does not match
leading tags.  You may create a tag ``<any>`` or include an ``any``
attribute in the tag.  An ``any`` tag matches any tag, while the
attribute matches any and all attributes.

When a match fails, the reformatted example and gotten text is
displayed (indented), and a rough diff-like output is given.  Anything
marked with ``+`` is in the output but wasn't supposed to be, and
similarly ``-`` means its in the example but wasn't in the output.

You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP``
"""

from lxml import etree
import sys
import re
import doctest
try:
    from html import escape as html_escape
except ImportError:
    from cgi import escape as html_escape

__all__ = ['PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP', 'LXMLOutputChecker',
           'LHTMLOutputChecker', 'install', 'temp_install']

try:
    _basestring = basestring
except NameError:
    _basestring = (str, bytes)

_IS_PYTHON_3 = sys.version_info[0] >= 3

PARSE_HTML = doctest.register_optionflag('PARSE_HTML')
PARSE_XML = doctest.register_optionflag('PARSE_XML')
NOPARSE_MARKUP = doctest.register_optionflag('NOPARSE_MARKUP')

OutputChecker = doctest.OutputChecker

def strip(v):
    if v is None:
        return None
    else:
        return v.strip()

def norm_whitespace(v):
    return _norm_whitespace_re.sub(' ', v)

_html_parser = etree.HTMLParser(recover=False, remove_blank_text=True)

def html_fromstring(html):
    return etree.fromstring(html, _html_parser)

# We use this to distinguish repr()s from elements:
_repr_re = re.compile(r'^<[^>]+ (at|object) ')
_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')

class LXMLOutputChecker(OutputChecker):

    empty_tags = (
        'param', 'img', 'area', 'br', 'basefont', 'input',
        'base', 'meta', 'link', 'col')

    def get_default_parser(self):
        return etree.XML

    def check_output(self, want, got, optionflags):
        alt_self = getattr(self, '_temp_override_self', None)
        if alt_self is not None:
            super_method = self._temp_call_super_check_output
            self = alt_self
        else:
            super_method = OutputChecker.check_output
        parser = self.get_parser(want, got, optionflags)
        if not parser:
            return super_method(
                self, want, got, optionflags)
        try:
            want_doc = parser(want)
        except etree.XMLSyntaxError:
            return False
        try:
            got_doc = parser(got)
        except etree.XMLSyntaxError:
            return False
        return self.compare_docs(want_doc, got_doc)

    def get_parser(self, want, got, optionflags):
        parser = None
        if NOPARSE_MARKUP & optionflags:
            return None
        if PARSE_HTML & optionflags:
            parser = html_fromstring
        elif PARSE_XML & optionflags:
            parser = etree.XML
        elif (want.strip().lower().startswith('<html')
              and got.strip().startswith('<html')):
            parser = html_fromstring
        elif (self._looks_like_markup(want)
              and self._looks_like_markup(got)):
            parser = self.get_default_parser()
        return parser

    def _looks_like_markup(self, s):
        s = s.strip()
        return (s.startswith('<')
                and not _repr_re.search(s))

    def compare_docs(self, want, got):
        if not self.tag_compare(want.tag, got.tag):
            return False
        if not self.text_compare(want.text, got.text, True):
            return False
        if not self.text_compare(want.tail, got.tail, True):
            return False
        if 'any' not in want.attrib:
            want_keys = sorted(want.attrib.keys())
            got_keys = sorted(got.attrib.keys())
            if want_keys != got_keys:
                return False
            for key in want_keys:
                if not self.text_compare(want.attrib[key], got.attrib[key], False):
                    return False
        if want.text != '...' or len(want):
            want_children = list(want)
            got_children = list(got)
            while want_children or got_children:
                if not want_children or not got_children:
                    return False
                want_first = want_children.pop(0)
                got_first = got_children.pop(0)
                if not self.compare_docs(want_first, got_first):
                    return False
                if not got_children and want_first.tail == '...':
                    break
        return True

    def text_compare(self, want, got, strip):
        want = want or ''
        got = got or ''
        if strip:
            want = norm_whitespace(want).strip()
            got = norm_whitespace(got).strip()
        want = '^%s$' % re.escape(want)
        want = want.replace(r'\.\.\.', '.*')
        if re.search(want, got):
            return True
        else:
            return False

    def tag_compare(self, want, got):
        if want == 'any':
            return True
        if (not isinstance(want, _basestring)
            or not isinstance(got, _basestring)):
            return want == got
        want = want or ''
        got = got or ''
        if want.startswith('{...}'):
            # Ellipsis on the namespace
            return want.split('}')[-1] == got.split('}')[-1]
        else:
            return want == got

    def output_difference(self, example, got, optionflags):
        want = example.want
        parser = self.get_parser(want, got, optionflags)
        errors = []
        if parser is not None:
            try:
                want_doc = parser(want)
            except etree.XMLSyntaxError:
                e = sys.exc_info()[1]
                errors.append('In example: %s' % e)
            try:
                got_doc = parser(got)
            except etree.XMLSyntaxError:
                e = sys.exc_info()[1]
                errors.append('In actual output: %s' % e)
        if parser is None or errors:
            value = OutputChecker.output_difference(
                self, example, got, optionflags)
            if errors:
                errors.append(value)
                return '\n'.join(errors)
            else:
                return value
        html = parser is html_fromstring
        diff_parts = []
        diff_parts.append('Expected:')
        diff_parts.append(self.format_doc(want_doc, html, 2))
        diff_parts.append('Got:')
        diff_parts.append(self.format_doc(got_doc, html, 2))
        diff_parts.append('Diff:')
        diff_parts.append(self.collect_diff(want_doc, got_doc, html, 2))
        return '\n'.join(diff_parts)

    def html_empty_tag(self, el, html=True):
        if not html:
            return False
        if el.tag not in self.empty_tags:
            return False
        if el.text or len(el):
            # This shouldn't happen (contents in an empty tag)
            return False
        return True

    def format_doc(self, doc, html, indent, prefix=''):
        parts = []
        if not len(doc):
            # No children...
            parts.append(' '*indent)
            parts.append(prefix)
            parts.append(self.format_tag(doc))
            if not self.html_empty_tag(doc, html):
                if strip(doc.text):
                    parts.append(self.format_text(doc.text))
                parts.append(self.format_end_tag(doc))
            if strip(doc.tail):
                parts.append(self.format_text(doc.tail))
            parts.append('\n')
            return ''.join(parts)
        parts.append(' '*indent)
        parts.append(prefix)
        parts.append(self.format_tag(doc))
        if not self.html_empty_tag(doc, html):
            parts.append('\n')
            if strip(doc.text):
                parts.append(' '*indent)
                parts.append(self.format_text(doc.text))
                parts.append('\n')
            for el in doc:
                parts.append(self.format_doc(el, html, indent+2))
            parts.append(' '*indent)
            parts.append(self.format_end_tag(doc))
            parts.append('\n')
        if strip(doc.tail):
            parts.append(' '*indent)
            parts.append(self.format_text(doc.tail))
            parts.append('\n')
        return ''.join(parts)

    def format_text(self, text, strip=True):
        if text is None:
            return ''
        if strip:
            text = text.strip()
        return html_escape(text, 1)

    def format_tag(self, el):
        attrs = []
        if isinstance(el, etree.CommentBase):
            # FIXME: probably PIs should be handled specially too?
            return '<!--'
        for name, value in sorted(el.attrib.items()):
            attrs.append('%s="%s"' % (name, self.format_text(value, False)))
        if not attrs:
            return '<%s>' % el.tag
        return '<%s %s>' % (el.tag, ' '.join(attrs))
    
    def format_end_tag(self, el):
        if isinstance(el, etree.CommentBase):
            # FIXME: probably PIs should be handled specially too?
            return '-->'
        return '</%s>' % el.tag

    def collect_diff(self, want, got, html, indent):
        parts = []
        if not len(want) and not len(got):
            parts.append(' '*indent)
            parts.append(self.collect_diff_tag(want, got))
            if not self.html_empty_tag(got, html):
                parts.append(self.collect_diff_text(want.text, got.text))
                parts.append(self.collect_diff_end_tag(want, got))
            parts.append(self.collect_diff_text(want.tail, got.tail))
            parts.append('\n')
            return ''.join(parts)
        parts.append(' '*indent)
        parts.append(self.collect_diff_tag(want, got))
        parts.append('\n')
        if strip(want.text) or strip(got.text):
            parts.append(' '*indent)
            parts.append(self.collect_diff_text(want.text, got.text))
            parts.append('\n')
        want_children = list(want)
        got_children = list(got)
        while want_children or got_children:
            if not want_children:
                parts.append(self.format_doc(got_children.pop(0), html, indent+2, '+'))
                continue
            if not got_children:
                parts.append(self.format_doc(want_children.pop(0), html, indent+2, '-'))
                continue
            parts.append(self.collect_diff(
                want_children.pop(0), got_children.pop(0), html, indent+2))
        parts.append(' '*indent)
        parts.append(self.collect_diff_end_tag(want, got))
        parts.append('\n')
        if strip(want.tail) or strip(got.tail):
            parts.append(' '*indent)
            parts.append(self.collect_diff_text(want.tail, got.tail))
            parts.append('\n')
        return ''.join(parts)

    def collect_diff_tag(self, want, got):
        if not self.tag_compare(want.tag, got.tag):
            tag = '%s (got: %s)' % (want.tag, got.tag)
        else:
            tag = got.tag
        attrs = []
        any = want.tag == 'any' or 'any' in want.attrib
        for name, value in sorted(got.attrib.items()):
            if name not in want.attrib and not any:
                attrs.append('+%s="%s"' % (name, self.format_text(value, False)))
            else:
                if name in want.attrib:
                    text = self.collect_diff_text(want.attrib[name], value, False)
                else:
                    text = self.format_text(value, False)
                attrs.append('%s="%s"' % (name, text))
        if not any:
            for name, value in sorted(want.attrib.items()):
Loading ...