From c3bb9588651969c3b6161aa7530d8d4becd2753f Mon Sep 17 00:00:00 2001 From: Joffrey F Date: Fri, 4 May 2018 14:06:03 -0700 Subject: [PATCH] Ignore default platform if API version doesn't support platform param Signed-off-by: Joffrey F --- compose/project.py | 3 +- compose/service.py | 18 ++++++++---- tests/unit/project_test.py | 9 +++--- tests/unit/service_test.py | 57 +++++++++++++++++++++++++++++++++++++- 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/compose/project.py b/compose/project.py index 924390b4e..c27794fc3 100644 --- a/compose/project.py +++ b/compose/project.py @@ -128,7 +128,8 @@ class Project(object): volumes_from=volumes_from, secrets=secrets, pid_mode=pid_mode, - platform=service_dict.pop('platform', default_platform), + platform=service_dict.pop('platform', None), + default_platform=default_platform, **service_dict) ) diff --git a/compose/service.py b/compose/service.py index ae9e0bb08..4ff56eea7 100644 --- a/compose/service.py +++ b/compose/service.py @@ -172,6 +172,7 @@ class Service(object): secrets=None, scale=None, pid_mode=None, + default_platform=None, **options ): self.name = name @@ -185,6 +186,7 @@ class Service(object): self.networks = networks or {} self.secrets = secrets or [] self.scale_num = scale or 1 + self.default_platform = default_platform self.options = options def __repr__(self): @@ -358,6 +360,13 @@ class Service(object): def image_name(self): return self.options.get('image', '{s.project}_{s.name}'.format(s=self)) + @property + def platform(self): + platform = self.options.get('platform') + if not platform and version_gte(self.client.api_version, '1.35'): + platform = self.default_platform + return platform + def convergence_plan(self, strategy=ConvergenceStrategy.changed): containers = self.containers(stopped=True) @@ -1018,8 +1027,7 @@ class Service(object): if not six.PY3 and not IS_WINDOWS_PLATFORM: path = path.encode('utf8') - platform = self.options.get('platform') - if platform and version_lt(self.client.api_version, '1.35'): + if self.platform and version_lt(self.client.api_version, '1.35'): raise OperationFailedError( 'Impossible to perform platform-targeted builds for API version < 1.35' ) @@ -1044,7 +1052,7 @@ class Service(object): }, gzip=gzip, isolation=build_opts.get('isolation', self.options.get('isolation', None)), - platform=platform, + platform=self.platform, ) try: @@ -1150,14 +1158,14 @@ class Service(object): kwargs = { 'tag': tag or 'latest', 'stream': True, - 'platform': self.options.get('platform'), + 'platform': self.platform, } if not silent: log.info('Pulling %s (%s%s%s)...' % (self.name, repo, separator, tag)) if kwargs['platform'] and version_lt(self.client.api_version, '1.35'): raise OperationFailedError( - 'Impossible to perform platform-targeted builds for API version < 1.35' + 'Impossible to perform platform-targeted pulls for API version < 1.35' ) try: output = self.client.pull(repo, **kwargs) diff --git a/tests/unit/project_test.py b/tests/unit/project_test.py index 1b6b6651f..1cc841814 100644 --- a/tests/unit/project_test.py +++ b/tests/unit/project_test.py @@ -29,6 +29,7 @@ class ProjectTest(unittest.TestCase): def setUp(self): self.mock_client = mock.create_autospec(docker.APIClient) self.mock_client._general_configs = {} + self.mock_client.api_version = docker.constants.DEFAULT_DOCKER_API_VERSION def test_from_config_v1(self): config = Config( @@ -578,21 +579,21 @@ class ProjectTest(unittest.TestCase): ) project = Project.from_config(name='test', client=self.mock_client, config_data=config_data) - assert project.get_service('web').options.get('platform') is None + assert project.get_service('web').platform is None project = Project.from_config( name='test', client=self.mock_client, config_data=config_data, default_platform='windows' ) - assert project.get_service('web').options.get('platform') == 'windows' + assert project.get_service('web').platform == 'windows' service_config['platform'] = 'linux/s390x' project = Project.from_config(name='test', client=self.mock_client, config_data=config_data) - assert project.get_service('web').options.get('platform') == 'linux/s390x' + assert project.get_service('web').platform == 'linux/s390x' project = Project.from_config( name='test', client=self.mock_client, config_data=config_data, default_platform='windows' ) - assert project.get_service('web').options.get('platform') == 'linux/s390x' + assert project.get_service('web').platform == 'linux/s390x' @mock.patch('compose.parallel.ParallelStreamWriter._write_noansi') def test_error_parallel_pull(self, mock_write): diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index d50db9044..f5a35d814 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -446,6 +446,20 @@ class ServiceTest(unittest.TestCase): with pytest.raises(OperationFailedError): service.pull() + def test_pull_image_with_default_platform(self): + self.mock_client.api_version = '1.35' + + service = Service( + 'foo', client=self.mock_client, image='someimage:sometag', + default_platform='linux' + ) + assert service.platform == 'linux' + service.pull() + + assert self.mock_client.pull.call_count == 1 + call_args = self.mock_client.pull.call_args + assert call_args[1]['platform'] == 'linux' + @mock.patch('compose.service.Container', autospec=True) def test_recreate_container(self, _): mock_container = mock.create_autospec(Container) @@ -538,7 +552,7 @@ class ServiceTest(unittest.TestCase): assert self.mock_client.build.call_count == 1 assert not self.mock_client.build.call_args[1]['pull'] - def test_build_does_with_platform(self): + def test_build_with_platform(self): self.mock_client.api_version = '1.35' self.mock_client.build.return_value = [ b'{"stream": "Successfully built 12345"}', @@ -551,6 +565,47 @@ class ServiceTest(unittest.TestCase): call_args = self.mock_client.build.call_args assert call_args[1]['platform'] == 'linux' + def test_build_with_default_platform(self): + self.mock_client.api_version = '1.35' + self.mock_client.build.return_value = [ + b'{"stream": "Successfully built 12345"}', + ] + + service = Service( + 'foo', client=self.mock_client, build={'context': '.'}, + default_platform='linux' + ) + assert service.platform == 'linux' + service.build() + + assert self.mock_client.build.call_count == 1 + call_args = self.mock_client.build.call_args + assert call_args[1]['platform'] == 'linux' + + def test_service_platform_precedence(self): + self.mock_client.api_version = '1.35' + + service = Service( + 'foo', client=self.mock_client, platform='linux/arm', + default_platform='osx' + ) + assert service.platform == 'linux/arm' + + def test_service_ignore_default_platform_with_unsupported_api(self): + self.mock_client.api_version = '1.32' + self.mock_client.build.return_value = [ + b'{"stream": "Successfully built 12345"}', + ] + + service = Service( + 'foo', client=self.mock_client, default_platform='windows', build={'context': '.'} + ) + assert service.platform is None + service.build() + assert self.mock_client.build.call_count == 1 + call_args = self.mock_client.build.call_args + assert call_args[1]['platform'] is None + def test_build_with_override_build_args(self): self.mock_client.build.return_value = [ b'{"stream": "Successfully built 12345"}',