from __future__ import absolute_import, unicode_literals
import inspect
import logging
import os
import platform
import sys
import time
import types
from contextlib import contextmanager
from functools import wraps
from six import reraise, string_types, iteritems as items
from six.moves import builtins
from .utils import WhateverIO, decorator, get_logger_handlers, noop
# ``mock`` is part of the standard library on Python 3; fall back to the
# standalone package of the same name when ``unittest.mock`` is unavailable.
try:
    from unittest import mock
except ImportError:
    import mock  # noqa

PY3 = sys.version_info[0] >= 3

# Dotted path used when patching the builtin ``open``: the module holding
# the builtins is called ``builtins`` on Python 3 and ``__builtin__`` before.
open_fqdn = 'builtins.open' if PY3 else '__builtin__.open'  # noqa

# ``str`` on Python 3, ``bytes`` otherwise.
module_name_t = str if PY3 else bytes  # noqa
# Names exported by a star-import of this module.
# NOTE(review): 'platform_pyimp' and 'sys_platform' are listed below, but no
# definition with either name is visible in this part of the file -- confirm
# they are defined elsewhere, otherwise a star-import will fail on them.
__all__ = [
    'ANY', 'ContextMock', 'MagicMock', 'Mock', 'MockCallbacks',
    'call', 'patch', 'sentinel',
    'wrap_logger', 'environ', 'sleepdeprived', 'mask_modules',
    'stdouts', 'replace_module_value', 'sys_version', 'pypy_version',
    'platform_pyimp', 'sys_platform', 'reset_modules', 'module',
    'open', 'restore_logging', 'module_exists', 'create_patcher',
]
# Aliases for helpers of the underlying ``mock`` module; all four are also
# listed in ``__all__`` above.
ANY = mock.ANY
call = mock.call
patch = mock.patch
sentinel = mock.sentinel
def create_patcher(*partial_path):
    """Create a :func:`patch` shortcut bound to a dotted path prefix.

    The returned callable takes the last path component ``name``, appends
    it to ``partial_path``, joins everything with dots and forwards the
    resulting target, together with any keyword arguments, to
    :func:`patch`.

    Example::

        patch_result = create_patcher('celery', 'result')
        with patch_result('sleep') as sleep:
            ...
    """
    def patcher(name, **kwargs):
        target = ".".join(partial_path + (name,))
        return patch(target, **kwargs)

    return patcher
class MockMixin(object):
    """Extra convenience methods shared by :class:`Mock` and
    :class:`MagicMock`.

    The mixin is listed *after* the ``mock`` base class in both
    subclasses, so a method defined here is only reached when the
    underlying mock class does not define one with the same name.
    """

    def on_nth_call_do(self, side_effect, n=1):
        """Change Mock side effect after ``n`` calls.

        The call that brings ``call_count`` up to ``n`` still returns
        ``return_value``; ``side_effect`` is installed during that call
        and therefore applies from the following call on.

        Returns the mock itself so calls can be chained.

        Example::

            mock.on_nth_call_do(RuntimeError, n=3)
        """
        def on_call(*args, **kwargs):
            if self.call_count >= n:
                self.side_effect = side_effect
            return self.return_value
        self.side_effect = on_call
        return self

    def on_nth_call_return(self, retval, n=1):
        """Change Mock to return specific return value after ``n`` calls.

        Returns the mock itself so calls can be chained.

        Example::

            mock.on_nth_call_return('STOP', n=3)
        """
        def on_call(*args, **kwargs):
            if self.call_count >= n:
                self.return_value = retval
            return self.return_value
        self.side_effect = on_call
        return self

    def _mock_update_attributes(self, attrs=None, **kwargs):
        """Set every ``key: value`` pair of ``attrs`` as an attribute.

        ``attrs`` defaults to ``None`` (nothing to set) instead of a shared
        mutable ``{}``.  Any other keyword arguments are accepted and
        ignored, which lets callers forward their constructor ``**kwargs``
        unfiltered.
        """
        if not attrs:
            return
        for key, value in attrs.items():
            setattr(self, key, value)

    def assert_not_called(_mock_self):
        """assert that the mock was never called."""
        self = _mock_self
        if self.call_count != 0:
            msg = ("Expected '%s' to not have been called. Called %s times." %
                   (self._mock_name or 'mock', self.call_count))
            raise AssertionError(msg)

    def assert_called(_mock_self):
        """assert that the mock was called at least once."""
        self = _mock_self
        if self.call_count == 0:
            # The parentheses matter: ``%`` binds tighter than ``or``, so
            # without them an unnamed mock is reported as 'None'.
            msg = ("Expected '%s' to have been called." %
                   (self._mock_name or 'mock'))
            raise AssertionError(msg)

    def assert_called_once(_mock_self):
        """assert that the mock was called only once."""
        self = _mock_self
        if not self.call_count == 1:
            msg = ("Expected '%s' to have been called once. Called %s times." %
                   (self._mock_name or 'mock', self.call_count))
            raise AssertionError(msg)
class Mock(mock.Mock, MockMixin):
    """:class:`mock.Mock` combined with the :class:`MockMixin` helpers."""

    def __init__(self, *args, **kwargs):
        super(Mock, self).__init__(*args, **kwargs)
        # The keyword arguments are forwarded a second time so that the
        # mixin can apply them after the base class has been initialised.
        self._mock_update_attributes(**kwargs)
class MagicMock(mock.MagicMock, MockMixin):
    """:class:`mock.MagicMock` combined with the :class:`MockMixin` helpers."""

    def __init__(self, *args, **kwargs):
        super(MagicMock, self).__init__(*args, **kwargs)
        # The keyword arguments are forwarded a second time so that the
        # mixin can apply them after the base class has been initialised.
        self._mock_update_attributes(**kwargs)
class _ContextMock(Mock):
    """Mock subclass providing ``__enter__`` and ``__exit__`` on the class.

    The :keyword:`with` statement requires these two methods to be
    implemented by the class, not just set on the instance.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return None
def ContextMock(*args, **kwargs):
    """Mock that mocks :keyword:`with` statement contexts.

    All arguments are passed to the underlying mock.  Entering the
    returned object yields the object itself, and its ``__enter__`` and
    ``__exit__`` are child mocks.
    """
    ctx = _ContextMock(*args, **kwargs)
    for dunder in ('__enter__', '__exit__'):
        ctx.attach_mock(_ContextMock(), dunder)
    ctx.__enter__.return_value = ctx
    # If __exit__ returned a value the exception raised inside the
    # ``with`` body would be ignored, so it must return None here.
    ctx.__exit__.return_value = None
    return ctx
def _bind(f, o):
    """Return a wrapper that calls ``f`` with ``o`` as first argument.

    The wrapper carries the metadata of ``f`` (via
    :func:`functools.wraps`) and forwards all of its own positional and
    keyword arguments after ``o``.
    """
    def bound_meth(*fargs, **fkwargs):
        return f(o, *fargs, **fkwargs)

    return wraps(f)(bound_meth)
if PY3:  # pragma: no cover
    def _get_class_fun(meth):
        """Return ``meth`` unchanged."""
        return meth

    def module_name(s):
        """Return ``s``, decoded first when it is a ``bytes`` instance."""
        return s.decode() if isinstance(s, bytes) else s
else:
    def _get_class_fun(meth):  # noqa
        """Return the function stored in ``meth.__func__``."""
        return meth.__func__

    def module_name(s):  # noqa
        """Return ``s``, encoded first when it is a ``unicode`` instance."""
        return s.encode() if isinstance(s, unicode) else s  # noqa
class MockCallbacks(object):
    """Base class whose "instances" are :class:`Mock` objects.

    Instantiating a subclass does not return an instance of that subclass.
    Instead a :class:`Mock` named after the class is created, the class'
    ``__init__`` is run with that mock as ``self``, and then every entry
    of the class namespace is copied onto the mock: functions and methods
    become the ``side_effect`` of the same-named child mock (bound to the
    mock), everything else is set as a plain attribute.
    """

    def __new__(cls, *args, **kwargs):
        r = Mock(name=cls.__name__)
        _get_class_fun(cls.__init__)(r, *args, **kwargs)
        skipped = ('__dict__', '__weakref__', '__new__', '__init__')
        for key, value in items(vars(cls)):
            if key in skipped:
                continue
            is_callback = inspect.ismethod(value) or inspect.isfunction(value)
            if is_callback:
                r.__getattr__(key).side_effect = _bind(value, r)
            else:
                r.__setattr__(key, value)
        return r
@decorator
def wrap_logger(logger, loglevel=logging.ERROR):
    """Wrap :class:`logging.Logger` with a StringIO() handler.

    For the duration of the context the handlers of ``logger`` are replaced
    by a single :class:`logging.StreamHandler` writing to an in-memory
    buffer; the previous handlers are put back afterwards.

    Yields the buffer.

    Note: ``loglevel`` is accepted but currently not applied to the logger
    or to the handler.

    Example::

        with mock.wrap_logger(logger, loglevel=logging.DEBUG) as sio:
            ...
            sio.getvalue()
    """
    saved_handlers = get_logger_handlers(logger)
    buf = WhateverIO()
    logger.handlers = [logging.StreamHandler(buf)]
    try:
        yield buf
    finally:
        logger.handlers = saved_handlers
@decorator
def environ(env_name, env_value):
    """Mock environment variable value.

    Sets ``os.environ[env_name]`` to ``env_value`` for the duration of the
    test/context.  Afterwards the previous value is restored, or the
    variable is removed again if it was not set before.

    Example::

        @mock.environ('DJANGO_SETTINGS_MODULE', 'proj.settings')
        def test_other_settings(self):
            ...
    """
    # Private marker meaning "variable was not set"; deliberately not
    # called ``sentinel`` so the module-level name is not shadowed.
    missing = object()
    previous = os.environ.get(env_name, missing)
    os.environ[env_name] = env_value
    try:
        yield
    finally:
        if previous is not missing:
            os.environ[env_name] = previous
        else:
            os.environ.pop(env_name, None)
@decorator
def sleepdeprived(module=time):
    """Mock time.sleep to do nothing.

    Replaces the ``sleep`` attribute of ``module`` (the :mod:`time` module
    by default) with a no-op and restores the original afterwards.

    Example::

        @mock.sleepdeprived()  # < patches time.sleep
        @mock.sleepdeprived(celery.result)  # < patches celery.result.sleep
    """
    real_sleep = module.sleep
    module.sleep = noop
    try:
        yield
    finally:
        module.sleep = real_sleep
# Taken from
# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
@decorator
def mask_modules(*modnames):
    """Ban some modules from being importable inside the context

    While active, ``builtins.__import__`` is replaced by a wrapper that
    raises :exc:`ImportError` for any name listed in ``modnames`` and
    delegates every other import to the real implementation.

    For example::

        >>> with mask_modules('sys'):
        ...     try:
        ...         import sys
        ...     except ImportError:
        ...         print('sys not found')
        sys not found

        >>> import sys  # noqa -- importable again outside the context

    Or as a decorator::

        @mask_modules('sys')
        def test_foo(self):
            ...
    """
    original_import = builtins.__import__

    def guarded_import(name, *args, **kwargs):
        if name not in modnames:
            return original_import(name, *args, **kwargs)
        raise ImportError('No module named %s' % name)

    builtins.__import__ = guarded_import
    try:
        yield
    finally:
        builtins.__import__ = original_import
@decorator
def stdouts():
    """Override `sys.stdout` and `sys.stderr` with `StringIO`
    instances.

    ``sys.__stdout__`` and ``sys.__stderr__`` are redirected to the same
    two buffers.  All four streams are restored when the test/context
    returns.

    Decorator example::

        @mock.stdouts
        def test_foo(self, stdout, stderr):
            something()
            self.assertIn('foo', stdout.getvalue())

    Context example::

        with mock.stdouts() as (stdout, stderr):
            something()
            self.assertIn('foo', stdout.getvalue())
    """
    saved = (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__)
    out, err = WhateverIO(), WhateverIO()
    sys.stdout = sys.__stdout__ = out
    sys.stderr = sys.__stderr__ = err
    try:
        yield out, err
    finally:
        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = saved
@decorator
def replace_module_value(module, name, value=None):
    """Mock module value, given a module, attribute name and value.

    While active, ``module.<name>`` is set to ``value``.  Passing
    ``value=None`` (the default) removes the attribute instead.  Falsy
    values other than ``None`` (``0``, ``''``, ``False``, ...) are set like
    any other value.

    When the test/context returns the attribute is put back exactly as it
    was found: restored to its previous value (including a previous value
    of ``None``) if it existed, removed if it did not.

    Decorator example::

        @mock.replace_module_value(module, 'CONSTANT', 3.03)
        def test_foo(self):
            ...

    Context example::

        with mock.replace_module_value(module, 'CONSTANT', 3.03):
            ...
    """
    # A private marker distinguishes "attribute absent" from "attribute
    # present with value None".
    missing = object()
    prev = getattr(module, name, missing)
    if value is not None:
        setattr(module, name, value)
    else:
        try:
            delattr(module, name)
        except AttributeError:
            # Nothing to remove: the attribute did not exist.
            pass
    try:
        yield
    finally:
        if prev is not missing:
            setattr(module, name, prev)
        else:
            try:
                delattr(module, name)
            except AttributeError:
                pass
def sys_version(value=None):
    """Mock :data:`sys.version_info`

    Thin wrapper around :func:`replace_module_value` for the
    ``version_info`` attribute of :mod:`sys`.

    Decorator example::

        @mock.sys_version((3, 6, 1))
        def test_foo(self):
            ...

    Context example::

        with mock.sys_version((3, 6, 1)):
            ...
    """
    attribute = 'version_info'
    return replace_module_value(sys, attribute, value)
def pypy_version(value=None):
    """Mock :data:`sys.pypy_version_info`

    Thin wrapper around :func:`replace_module_value` for the
    ``pypy_version_info`` attribute of :mod:`sys`.

    Decorator example::

        @mock.pypy_version((3, 6, 1))
        def test_foo(self):
            ...

    Context example::

        with mock.pypy_version((3, 6, 1)):
            ...
    """
    attribute = 'pypy_version_info'
    return replace_module_value(sys, attribute, value)
@decorator
def reset_modules(*modules):
    """Remove modules from :data:`sys.modules` by name,
    and reset back again when the test/context returns.

    Names that are not present in :data:`sys.modules` are ignored.  Only
    the removed entries are put back afterwards; anything imported while
    the context was active is left in place unless it is overwritten by a
    restored entry of the same name.

    Decorator example::

        @mock.reset_modules('celery.result', 'celery.app.base')
        def test_foo(self):
            pass

    Context example::

        with mock.reset_modules('celery.result', 'celery.app.base'):
            pass
    """
    # NOTE: must be wrapped by ``@decorator`` exactly once -- the generator
    # function itself has to be what gets turned into the context manager.
    prev = {
        name: sys.modules.pop(name)
        for name in modules if name in sys.modules
    }
    try:
        yield
    finally:
        sys.modules.update(prev)
@decorator
def module(*names):
    """Mock one or more modules such that every attribute is a :class:`Mock`.

    Each name is registered in :data:`sys.modules` as a module object that
    creates (and remembers) a :class:`Mock` for any attribute that is
    looked up but missing.  The list of these module objects is yielded,
    in the order of ``names``.  Afterwards the previous
    :data:`sys.modules` entries are restored, and names that had no entry
    are removed again.
    """
    class MockModule(types.ModuleType):

        def __getattr__(self, attr):
            # Only reached when normal lookup fails: store a fresh Mock,
            # then return it through the regular lookup path.
            setattr(self, attr, Mock())
            return types.ModuleType.__getattribute__(self, attr)

    displaced = {}
    created = []
    for name in names:
        if name in sys.modules:
            displaced[name] = sys.modules[name]
        fake = sys.modules[name] = MockModule(module_name(name))
        created.append(fake)
    try:
        yield created
    finally:
        for name in names:
            if name in displaced:
                sys.modules[name] = displaced[name]
            else:
                sys.modules.pop(name, None)
@contextmanager
def mock_context(mock, typ=Mock):
    """Make calling ``mock`` return an object usable in a ``with`` block.

    A new :class:`Mock` is installed as ``mock.return_value`` and given
    ``__enter__``/``__exit__`` attributes created by calling ``typ``.
    ``__enter__`` returns that same object, and ``__exit__`` re-raises the
    exception it is handed, if any.  The object is yielded to the caller.
    """
    context = mock.return_value = Mock()
    context.__enter__ = typ()
    context.__exit__ = typ()

    def on_exit(*exc_info):
        if not exc_info[0]:
            return
        reraise(exc_info[0], exc_info[1], exc_info[2])

    context.__exit__.side_effect = on_exit
    context.__enter__.return_value = context
    try:
        yield context
    finally:
        # NOTE(review): no ``reset`` method is defined on Mock/MockMixin in
        # this file, so this presumably just calls an auto-created child
        # mock; ``reset_mock()`` may have been intended -- confirm.
        context.reset()
@decorator
def open(typ=WhateverIO, side_effect=None):
    """Patch builtins.open so that it returns StringIO object.

    :param typ: File object for open to return.
        Defaults to :class:`WhateverIO`, a hybrid of
        :class:`io.StringIO` and :class:`io.BytesIO` accepting
        both bytes and unicode input.
    :param side_effect: Additional side effect for when the open context
        is entered.

    Decorator example::

        @mock.open()
        def test_foo(self, open_fh):
            something_opening_and_writing_a_file()
            self.assertIn('foo', open_fh.getvalue())

    Context example::

        with mock.open(io.BytesIO) as open_fh:
            something_opening_and_writing_bytes_to_a_file()
            self.assertIn(b'foo', open_fh.getvalue())
    """
    with patch(open_fqdn) as open_, mock_context(open_) as context:
        if side_effect is not None:
            context.__enter__.side_effect = side_effect
        val = typ()
        context.__enter__.return_value = val
        # Give the file object its own mock ``__exit__`` attribute.
        val.__exit__ = Mock()
        yield val
@decorator
def restore_logging():
    """Restore root logger handlers after test returns.

    Also restores the root logger level and the four standard stream
    attributes of :mod:`sys` (``stdout``, ``stderr``, ``__stdout__``,
    ``__stderr__``) to what they were on entry.

    Decorator example::

        @mock.restore_logging()
        def test_foo(self):
            setup_logging()

    Context example::

        with mock.restore_logging():
            setup_logging()
    """
    outs = sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__
    root = logging.getLogger()
    level = root.level
    # Take a *copy*: ``root.handlers`` is the live list that ``addHandler``
    # and ``removeHandler`` mutate in place, so keeping a reference to it
    # would make the restore below a no-op.
    handlers = root.handlers[:]
    try:
        yield
    finally:
        sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__ = outs
        # setLevel() rather than assigning ``root.level`` so the logging
        # module can update any state it derives from the level.
        root.setLevel(level)
        root.handlers[:] = handlers
@decorator
def module_exists(*modules):
    """Patch one or more modules to ensure they exist.

    Each argument is either a module name or a module object.  Names are
    turned into empty module objects; module objects are registered under
    their ``__name__``.

    A module name with multiple paths (e.g. gevent.monkey) will
    ensure all parent modules are also patched (``gevent`` +
    ``gevent.monkey``): a parent that is not in :data:`sys.modules` is
    created as an empty module, and the child is set as an attribute of
    its parent.

    When the test/context returns, every touched :data:`sys.modules` entry
    and every touched parent attribute is restored to its previous state
    (or removed if it did not exist before).

    Decorator example::

        @mock.module_exists('gevent.monkey')
        def test_foo(self):
            pass

    Context example::

        with mock.module_exists('gevent.monkey'):
            gevent.monkey.patch_all = Mock(name='patch_all')
            ...
    """
    missing = object()
    # name -> entry found in sys.modules before it was first replaced
    # (``missing`` when there was none).
    replaced = {}
    # (parent module, attribute name, previous attribute value), in the
    # order the attributes were set.
    patched_attrs = []

    def install(mod):
        # Register ``mod`` and link it into its (possibly new) parent.
        name = mod.__name__
        if name not in replaced:
            replaced[name] = sys.modules.get(name, missing)
        sys.modules[name] = mod
        parent_name, _, attr = name.rpartition('.')
        if not parent_name:
            return
        parent = sys.modules.get(parent_name)
        if parent is None:
            parent = types.ModuleType(module_name(parent_name))
            install(parent)
        patched_attrs.append((parent, attr, getattr(parent, attr, missing)))
        setattr(parent, attr, mod)

    # Setup happens inside the ``try`` so that a failure half-way through
    # still rolls back whatever was already installed.
    try:
        for mod in modules:
            if isinstance(mod, string_types):
                mod = types.ModuleType(module_name(mod))
            install(mod)
        yield
    finally:
        # Undo in reverse so that an attribute patched twice ends up with
        # its original value.
        for parent, attr, previous in reversed(patched_attrs):
            if previous is missing:
                try:
                    delattr(parent, attr)
                except AttributeError:
                    pass
            else:
                setattr(parent, attr, previous)
        for name, previous in replaced.items():
            if previous is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous