Convert some cli tests to pytest.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
This commit is contained in:
Daniel Nephin 2016-03-08 14:42:51 -05:00
parent 9f9dcc098a
commit 886328640f
1 changed files with 35 additions and 32 deletions

View File

@ -3,6 +3,8 @@ from __future__ import unicode_literals
import logging import logging
import pytest
from compose import container from compose import container
from compose.cli.errors import UserError from compose.cli.errors import UserError
from compose.cli.formatter import ConsoleWarningFormatter from compose.cli.formatter import ConsoleWarningFormatter
@ -11,7 +13,6 @@ from compose.cli.main import convergence_strategy_from_opts
from compose.cli.main import setup_console_handler from compose.cli.main import setup_console_handler
from compose.service import ConvergenceStrategy from compose.service import ConvergenceStrategy
from tests import mock from tests import mock
from tests import unittest
def mock_container(service, number): def mock_container(service, number):
@ -22,7 +23,14 @@ def mock_container(service, number):
name_without_project='{0}_{1}'.format(service, number)) name_without_project='{0}_{1}'.format(service, number))
class CLIMainTestCase(unittest.TestCase): @pytest.fixture
def logging_handler():
stream = mock.Mock()
stream.isatty.return_value = True
return logging.StreamHandler(stream=stream)
class TestCLIMainTestCase(object):
def test_build_log_printer(self): def test_build_log_printer(self):
containers = [ containers = [
@ -34,7 +42,7 @@ class CLIMainTestCase(unittest.TestCase):
] ]
service_names = ['web', 'db'] service_names = ['web', 'db']
log_printer = build_log_printer(containers, service_names, True, False, {'follow': True}) log_printer = build_log_printer(containers, service_names, True, False, {'follow': True})
self.assertEqual(log_printer.containers, containers[:3]) assert log_printer.containers == containers[:3]
def test_build_log_printer_all_services(self): def test_build_log_printer_all_services(self):
containers = [ containers = [
@ -44,58 +52,53 @@ class CLIMainTestCase(unittest.TestCase):
] ]
service_names = [] service_names = []
log_printer = build_log_printer(containers, service_names, True, False, {'follow': True}) log_printer = build_log_printer(containers, service_names, True, False, {'follow': True})
self.assertEqual(log_printer.containers, containers) assert log_printer.containers == containers
class SetupConsoleHandlerTestCase(unittest.TestCase): class TestSetupConsoleHandlerTestCase(object):
def setUp(self): def test_with_tty_verbose(self, logging_handler):
self.stream = mock.Mock() setup_console_handler(logging_handler, True)
self.stream.isatty.return_value = True assert type(logging_handler.formatter) == ConsoleWarningFormatter
self.handler = logging.StreamHandler(stream=self.stream) assert '%(name)s' in logging_handler.formatter._fmt
assert '%(funcName)s' in logging_handler.formatter._fmt
def test_with_tty_verbose(self): def test_with_tty_not_verbose(self, logging_handler):
setup_console_handler(self.handler, True) setup_console_handler(logging_handler, False)
assert type(self.handler.formatter) == ConsoleWarningFormatter assert type(logging_handler.formatter) == ConsoleWarningFormatter
assert '%(name)s' in self.handler.formatter._fmt assert '%(name)s' not in logging_handler.formatter._fmt
assert '%(funcName)s' in self.handler.formatter._fmt assert '%(funcName)s' not in logging_handler.formatter._fmt
def test_with_tty_not_verbose(self): def test_with_not_a_tty(self, logging_handler):
setup_console_handler(self.handler, False) logging_handler.stream.isatty.return_value = False
assert type(self.handler.formatter) == ConsoleWarningFormatter setup_console_handler(logging_handler, False)
assert '%(name)s' not in self.handler.formatter._fmt assert type(logging_handler.formatter) == logging.Formatter
assert '%(funcName)s' not in self.handler.formatter._fmt
def test_with_not_a_tty(self):
self.stream.isatty.return_value = False
setup_console_handler(self.handler, False)
assert type(self.handler.formatter) == logging.Formatter
class ConvergeStrategyFromOptsTestCase(unittest.TestCase): class TestConvergeStrategyFromOptsTestCase(object):
def test_invalid_opts(self): def test_invalid_opts(self):
options = {'--force-recreate': True, '--no-recreate': True} options = {'--force-recreate': True, '--no-recreate': True}
with self.assertRaises(UserError): with pytest.raises(UserError):
convergence_strategy_from_opts(options) convergence_strategy_from_opts(options)
def test_always(self): def test_always(self):
options = {'--force-recreate': True, '--no-recreate': False} options = {'--force-recreate': True, '--no-recreate': False}
self.assertEqual( assert (
convergence_strategy_from_opts(options), convergence_strategy_from_opts(options) ==
ConvergenceStrategy.always ConvergenceStrategy.always
) )
def test_never(self): def test_never(self):
options = {'--force-recreate': False, '--no-recreate': True} options = {'--force-recreate': False, '--no-recreate': True}
self.assertEqual( assert (
convergence_strategy_from_opts(options), convergence_strategy_from_opts(options) ==
ConvergenceStrategy.never ConvergenceStrategy.never
) )
def test_changed(self): def test_changed(self):
options = {'--force-recreate': False, '--no-recreate': False} options = {'--force-recreate': False, '--no-recreate': False}
self.assertEqual( assert (
convergence_strategy_from_opts(options), convergence_strategy_from_opts(options) ==
ConvergenceStrategy.changed ConvergenceStrategy.changed
) )