# Repository URL to install this package:
# Version: 1.0.7
import functools
import inspect
import os
import pytest
# Public names exported by ``from <module> import *``.
# The failure-state helpers (clear_failures, get_failures,
# set_stop_on_fail) are intentionally not listed here.
__all__ = [
    # context manager for wrapping plain ``assert`` statements
    "check",
    # equality / identity / truthiness checks
    "equal",
    "not_equal",
    "is_",
    "is_not",
    "is_true",
    "is_false",
    "is_none",
    "is_not_none",
    # membership and type checks
    "is_in",
    "is_not_in",
    "is_instance",
    "is_not_instance",
    # approximate and ordering comparisons
    "almost_equal",
    "not_almost_equal",
    "greater",
    "greater_equal",
    "less",
    "less_equal",
    # building blocks for custom checks
    "check_func",
    "raises",
]
# Module-wide state shared by every check helper below.
# When truthy, a failed check lets its AssertionError propagate instead of
# recording it and continuing; changed through set_stop_on_fail().
_stop_on_fail = False
# Failure entries appended by log_failure(); rebound by clear_failures().
_failures = []
def clear_failures():
    """Start a new, empty failure list.

    The module-level list is rebound to a fresh object rather than emptied in
    place, so a list previously obtained from ``get_failures`` keeps its
    entries.
    """
    global _failures
    _failures = list()
def get_failures():
    """Return the current module-level list of recorded failure entries.

    The list object itself is returned, not a copy; entries are strings
    appended by ``log_failure``.
    """
    current = _failures
    return current
def set_stop_on_fail(stop_on_fail):
    """Choose whether failed checks propagate or are recorded.

    Args:
        stop_on_fail: stored as-is (no ``bool`` coercion). When truthy, a
            failing check lets its AssertionError propagate; when falsy, the
            failure is recorded and execution continues.
    """
    global _stop_on_fail
    _stop_on_fail = stop_on_fail
class CheckContextManager(object):
    """Context manager that turns a failed ``assert`` into a recorded check failure.

    One shared instance, ``check``, is created below. ``with check:`` uses it
    directly; ``with check("message"):`` first stores ``message`` on that
    shared instance (see ``__call__``) so it is logged instead of the
    assertion's own text.

    NOTE(review): because the message lives on the shared instance, nesting
    ``with check("outer"):`` around ``with check("inner"):`` lets the inner
    block overwrite (and then clear) the outer message.
    """

    # Message to log for the next failure. Class-level default; set per use
    # by __call__ and reset to None on every exit.
    msg = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Decide whether the exception leaving the ``with`` body is swallowed.

        Returns True (suppressing it) only for an ``AssertionError`` while
        stop-on-fail is disabled, after logging it via ``log_failure``. Any
        other exception, or an ``AssertionError`` while stop-on-fail is
        enabled, propagates.
        """
        __tracebackhide__ = True
        try:
            if exc_type is None or not issubclass(exc_type, AssertionError):
                return None
            if _stop_on_fail:
                return None
            log_failure(self.msg if self.msg is not None else exc_val)
            return True
        finally:
            # Reset in ``finally`` so a custom message can never leak into the
            # next ``with check:`` block, even if logging itself raised.
            self.msg = None

    def __call__(self, msg=None):
        """Store ``msg`` for the next failure and return this same instance."""
        self.msg = msg
        return self


# Shared instance used as ``with check:`` / ``with check("msg"):``.
check = CheckContextManager()
def check_func(func):
    """Decorate an assertion helper so a failed ``assert`` becomes a check failure.

    The returned wrapper calls ``func`` with the same arguments. It returns
    True when ``func`` finishes without an ``AssertionError``; ``func``'s own
    return value is discarded. On ``AssertionError`` it either re-raises
    (stop-on-fail enabled) or logs the error with ``log_failure`` and returns
    False.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwds):
        __tracebackhide__ = True
        try:
            func(*args, **kwds)
        except AssertionError as e:
            if _stop_on_fail:
                # Bare ``raise`` keeps the original traceback without adding
                # a second entry for this frame (``raise e`` would).
                raise
            log_failure(e)
            return False
        return True

    return wrapper
@check_func
def equal(a, b, msg=""):
    """Check that ``a == b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a == b, msg
@check_func
def not_equal(a, b, msg=""):
    """Check that ``a != b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a != b, msg
@check_func
def is_(a, b, msg=""):
    """Check that ``a is b`` (object identity, not equality).

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a is b, msg
@check_func
def is_not(a, b, msg=""):
    """Check that ``a is not b`` (object identity, not equality).

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a is not b, msg
@check_func
def is_true(x, msg=""):
    """Check that ``x`` is truthy (``bool(x)`` is True).

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert bool(x), msg
@check_func
def is_false(x, msg=""):
    """Check that ``x`` is falsy (``bool(x)`` is False).

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert not bool(x), msg
@check_func
def is_none(x, msg=""):
    """Check that ``x is None``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert x is None, msg
@check_func
def is_not_none(x, msg=""):
    """Check that ``x is not None``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert x is not None, msg
@check_func
def is_in(a, b, msg=""):
    """Check that ``a in b`` (membership of ``a`` in container ``b``).

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a in b, msg
@check_func
def is_not_in(a, b, msg=""):
    """Check that ``a not in b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a not in b, msg
@check_func
def is_instance(a, b, msg=""):
    """Check that ``isinstance(a, b)``.

    ``b`` is passed straight to ``isinstance``, so a type or a tuple of types
    works. ``msg`` is the assertion message used on failure; see
    ``check_func`` for how a failure is reported and what the call returns.
    """
    assert isinstance(a, b), msg
@check_func
def is_not_instance(a, b, msg=""):
    """Check that ``not isinstance(a, b)``.

    ``b`` is passed straight to ``isinstance``, so a type or a tuple of types
    works. ``msg`` is the assertion message used on failure; see
    ``check_func`` for how a failure is reported and what the call returns.
    """
    assert not isinstance(a, b), msg
@check_func
def almost_equal(a, b, rel=None, abs=None, msg=""):
    """
    Check that ``a`` equals ``b`` within tolerance, via ``pytest.approx(b, rel, abs)``.

    For the meaning and defaults of the ``rel`` and ``abs`` tolerances, see
    https://docs.pytest.org/en/latest/builtin.html#pytest.approx
    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a == pytest.approx(b, rel, abs), msg
@check_func
def not_almost_equal(a, b, rel=None, abs=None, msg=""):
    """
    Check that ``a`` is NOT equal to ``b`` within tolerance, via ``pytest.approx(b, rel, abs)``.

    For the meaning and defaults of the ``rel`` and ``abs`` tolerances, see
    https://docs.pytest.org/en/latest/builtin.html#pytest.approx
    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a != pytest.approx(b, rel, abs), msg
@check_func
def greater(a, b, msg=""):
    """Check that ``a > b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a > b, msg
@check_func
def greater_equal(a, b, msg=""):
    """Check that ``a >= b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a >= b, msg
@check_func
def less(a, b, msg=""):
    """Check that ``a < b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a < b, msg
@check_func
def less_equal(a, b, msg=""):
    """Check that ``a <= b``.

    ``msg`` is the assertion message used on failure; see ``check_func`` for
    how a failure is reported and what the call returns.
    """
    assert a <= b, msg
def raises(expected_exception, *args, **kwargs):
    """
    Check that a given callable or context raises an error of a given type.

    Can be used either as a context manager::

        with raises(AssertionError):
            raise AssertionError

    or as a function::

        def raises_assert():
            raise AssertionError

        raises(AssertionError, raises_assert)

    `expected_exception` is a single exception class or a tuple of exception
    classes (as accepted by `issubclass`), so several possible exception
    types can be used. Anything else fails an assertion up front.

    When args[0] is callable, the remainder of args and all of kwargs except
    for one named `msg` are passed to args[0] as arguments. `msg`, if given,
    replaces the default text of the logged failure.

    Note that because `raises` is implemented using a context manager, the
    usual control flow warnings apply: within the context, execution stops on
    the first error encountered *and does not resume after this error has been
    logged*. Therefore, the line you expect to raise an error must be the last
    line of the context: any subsequent lines won't be executed. Pull such
    lines out of the context if they don't raise errors, or use more calls to
    `raises` if they do.

    This function is modeled loosely after Pytest's own `raises`, except for
    the latter's `match`-ing logic. We should strive to keep the call
    signature of this `raises` as close as possible to the other `raises`.

    Returns:
        A `CheckRaisesContext` when used as a context manager (no callable
        given); None in the function form.
    """
    __tracebackhide__ = True
    # Normalize to a tuple: a lone class (or any non-tuple value) is wrapped,
    # so invalid inputs such as a list or a string are rejected below instead
    # of being iterated element by element.
    if isinstance(expected_exception, tuple):
        expected_exceptions = expected_exception
    else:
        expected_exceptions = (expected_exception,)
    # ``and`` (not ``or``): each entry must be a class AND derive from
    # BaseException; short-circuiting also keeps ``issubclass`` from being
    # called on non-classes.
    assert all(
        isinstance(exc, type) and issubclass(exc, BaseException)
        for exc in expected_exceptions
    ), (
        "expected_exception must be an exception type or a tuple of "
        "exception types, got {!r}".format(expected_exception)
    )
    msg = kwargs.pop("msg", None)
    if not args:
        assert not kwargs, (
            f"Unexpected kwargs for pytest_check.raises: {kwargs}"
        )
        return CheckRaisesContext(*expected_exceptions, msg=msg)
    func = args[0]
    assert callable(func), "{!r} is not callable".format(func)
    # NOTE(review): in this function form the frame of `raises` itself sits
    # between the logging code and the test; if this module is installed
    # under site-packages, log_failure may stop its pseudo-trace here.
    with CheckRaisesContext(*expected_exceptions, msg=msg):
        func(*args[1:], **kwargs)
class CheckRaisesContext:
    """
    Helper context for `raises` that can be parameterized by error type.

    Note that CheckRaisesContext is instantiated whenever needed; it is not a
    global variable like `check`. Therefore, we don't need to curate
    `self.msg` in `__exit__` for this class like we do with
    CheckContextManager.

    On exit from the ``with`` body:

    * an exception matching ``expected_excs`` is suppressed (check passed);
    * no exception at all is a failed check: it is raised as
      ``AssertionError`` when stop-on-fail is enabled, otherwise logged, using
      ``msg`` or a ``"DID NOT RAISE ..."`` description;
    * any other exception propagates when stop-on-fail is enabled, otherwise
      it is logged as a failed check and suppressed.
    """

    def __init__(self, *expected_excs, msg=None):
        # Each positional argument is an exception class or a (possibly
        # nested) tuple of them -- any form ``issubclass`` accepts.
        self.expected_excs = expected_excs
        self.msg = msg

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        __tracebackhide__ = True
        if exc_type is not None and issubclass(exc_type, self.expected_excs):
            # This is the case where an error has occured within the context
            # but it is the type we're expecting. Therefore we return True
            # to silence this error and proceed with execution outside the
            # context.
            return True
        if exc_type is None:
            # Nothing was raised, which is itself a failed check. With
            # stop-on-fail there is no exception to propagate, so raise one.
            failure = self.msg or "DID NOT RAISE {}".format(self._expected_names())
            if _stop_on_fail:
                raise AssertionError(failure)
            log_failure(failure)
            return True
        if _stop_on_fail:
            # Returning something falsey lets the unexpected exception
            # propagate so Pytest reports it like any other unhandled
            # exception, with display and tracebacks.
            return None
        # Not stopping on failure: log the unexpected error as a failed check
        # and continue execution after the context without raising.
        log_failure(self.msg if self.msg else exc_val)
        return True

    def _expected_names(self):
        """Return the expected exception classes as a comma-separated string of names."""
        names = []
        pending = list(reversed(self.expected_excs))
        while pending:
            exc = pending.pop()
            if isinstance(exc, tuple):
                # Flatten nested tuples while keeping left-to-right order.
                pending.extend(reversed(exc))
            else:
                names.append(getattr(exc, "__name__", repr(exc)))
        return ", ".join(names)
def get_full_context(level):
    """Describe the stack frame ``level`` entries up from this function.

    Indexing follows ``inspect.stack()``: level 0 is ``get_full_context``
    itself, 1 is its caller, and so on.

    Returns:
        ``(filename, line, funcname, context)``. ``filename`` is relative to
        the current working directory when possible. ``context`` is the
        stripped source line, or ``""`` when unavailable.

    Raises:
        IndexError: if the stack has no frame at ``level``.
    """
    frame_info = inspect.stack()[level]
    filename, line, funcname, contextlist = frame_info[1:5]
    # Drop the frame reference promptly; at level 0 it is this very frame,
    # which would otherwise form a reference cycle (see the inspect docs).
    del frame_info
    try:
        filename = os.path.relpath(filename)
    except ValueError:
        # On Windows relpath raises when the file and the cwd are on
        # different drives; fall back to the path as reported.
        pass
    context = contextlist[0].strip() if contextlist else ""
    return (filename, line, funcname, context)
def log_failure(msg):
    """Append a failure entry with a short pseudo-traceback to ``_failures``.

    The walk starts at stack level 3 relative to ``get_full_context``. That
    is the caller of whatever called ``log_failure``, i.e. the code that ran
    the failing check. It moves outward one frame at a time and records
    ``file:line in func() -> source`` for each frame. It stops after a frame
    whose function name contains ``"test_"``, or before a frame whose file
    path contains ``"site-packages"``, or at the top of the stack.

    The entry is ``"FAILURE: <msg>\\n<trace>"``, with the outermost frame
    first. A falsy ``msg`` is rendered as an empty string.
    """
    __tracebackhide__ = True
    level = 3
    pseudo_trace = []
    func = ""
    while "test_" not in func:
        try:
            filename, lineno, func, context = get_full_context(level)
        except IndexError:
            # Reached the top of the stack without finding a test function
            # (e.g. a check called outside pytest); keep what was collected
            # instead of crashing.
            break
        if "site-packages" in filename:
            break
        pseudo_trace.append(
            "{}:{} in {}() -> {}".format(filename, lineno, func, context)
        )
        level += 1
    pseudo_trace_str = "\n".join(reversed(pseudo_trace))
    entry = "FAILURE: {}\n{}".format(msg if msg else "", pseudo_trace_str)
    _failures.append(entry)