diff --git a/compose/config/config.py b/compose/config/config.py index 864bc7e90..e784be176 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -97,6 +97,7 @@ DOCKER_CONFIG_KEYS = [ 'privileged', 'read_only', 'restart', + 'runtime', 'secrets', 'security_opt', 'shm_size', diff --git a/compose/config/config_schema_v2.3.json b/compose/config/config_schema_v2.3.json index 6f923871b..cedc2dae6 100644 --- a/compose/config/config_schema_v2.3.json +++ b/compose/config/config_schema_v2.3.json @@ -261,6 +261,7 @@ "privileged": {"type": "boolean"}, "read_only": {"type": "boolean"}, "restart": {"type": "string"}, + "runtime": {"type": "string"}, "scale": {"type": "integer"}, "security_opt": {"type": "array", "items": {"type": "string"}, "uniqueItems": true}, "shm_size": {"type": ["number", "string"]}, diff --git a/compose/service.py b/compose/service.py index 07db3ac5f..3e492267f 100644 --- a/compose/service.py +++ b/compose/service.py @@ -87,6 +87,7 @@ HOST_CONFIG_KEYS = [ 'pids_limit', 'privileged', 'restart', + 'runtime', 'security_opt', 'shm_size', 'storage_opt', @@ -858,6 +859,7 @@ class Service(object): dns_opt=options.get('dns_opt'), dns_search=options.get('dns_search'), restart_policy=options.get('restart'), + runtime=options.get('runtime'), cap_add=options.get('cap_add'), cap_drop=options.get('cap_drop'), mem_limit=options.get('mem_limit'), diff --git a/tests/integration/project_test.py b/tests/integration/project_test.py index 953dd52be..e966f8d88 100644 --- a/tests/integration/project_test.py +++ b/tests/integration/project_test.py @@ -22,6 +22,7 @@ from compose.config.types import VolumeSpec from compose.const import COMPOSEFILE_V2_0 as V2_0 from compose.const import COMPOSEFILE_V2_1 as V2_1 from compose.const import COMPOSEFILE_V2_2 as V2_2 +from compose.const import COMPOSEFILE_V2_3 as V2_3 from compose.const import COMPOSEFILE_V3_1 as V3_1 from compose.const import LABEL_PROJECT from compose.const import LABEL_SERVICE @@ -31,10 +32,12 @@ from compose.errors import NoHealthCheckConfigured from compose.project import Project from compose.project import ProjectError from compose.service import ConvergenceStrategy +from tests.integration.testcases import if_runtime_available from tests.integration.testcases import is_cluster from tests.integration.testcases import no_cluster from tests.integration.testcases import v2_1_only from tests.integration.testcases import v2_2_only +from tests.integration.testcases import v2_3_only from tests.integration.testcases import v2_only from tests.integration.testcases import v3_only @@ -971,6 +974,66 @@ class ProjectTest(DockerClientTestCase): with self.assertRaises(ProjectError): project.up() + @v2_3_only() + def test_up_with_runtime(self): + self.require_api_version('1.30') + config_data = build_config( + version=V2_3, + services=[{ + 'name': 'web', + 'image': 'busybox:latest', + 'runtime': 'runc' + }], + ) + project = Project.from_config( + client=self.client, + name='composetest', + config_data=config_data + ) + project.up(detached=True) + service_container = project.get_service('web').containers(stopped=True)[0] + assert service_container.inspect()['HostConfig']['Runtime'] == 'runc' + + @v2_3_only() + def test_up_with_invalid_runtime(self): + self.require_api_version('1.30') + config_data = build_config( + version=V2_3, + services=[{ + 'name': 'web', + 'image': 'busybox:latest', + 'runtime': 'foobar' + }], + ) + project = Project.from_config( + client=self.client, + name='composetest', + config_data=config_data + ) + with self.assertRaises(ProjectError): + project.up() + + @v2_3_only() + @if_runtime_available('nvidia') + def test_up_with_nvidia_runtime(self): + self.require_api_version('1.30') + config_data = build_config( + version=V2_3, + services=[{ + 'name': 'web', + 'image': 'busybox:latest', + 'runtime': 'nvidia' + }], + ) + project = Project.from_config( + client=self.client, + name='composetest', + config_data=config_data + ) + project.up(detached=True) + service_container = project.get_service('web').containers(stopped=True)[0] + assert service_container.inspect()['HostConfig']['Runtime'] == 'nvidia' + @v2_only() def test_project_up_with_network_internal(self): self.require_api_version('1.23') diff --git a/tests/integration/testcases.py b/tests/integration/testcases.py index b72fb53a8..84a97b133 100644 --- a/tests/integration/testcases.py +++ b/tests/integration/testcases.py @@ -5,6 +5,7 @@ import functools import os import pytest +import six from docker.errors import APIError from docker.utils import version_lt @@ -155,6 +156,25 @@ class DockerClientTestCase(unittest.TestCase): return self.client.inspect_volume(volumes[0]['Name']) +def if_runtime_available(runtime): + if runtime == 'nvidia': + command = 'nvidia-container-runtime' + if six.PY3: + import shutil + return pytest.mark.skipif( + shutil.which(command) is None, + reason="Nvida runtime not exists" + ) + return pytest.mark.skipif( + any( + os.access(os.path.join(path, command), os.X_OK) + for path in os.environ["PATH"].split(os.pathsep) + ) is False, + reason="Nvida runtime not exists" + ) + return pytest.skip("Runtime %s not exists", runtime) + + def is_cluster(client): if SWARM_ASSUME_MULTINODE: return True diff --git a/tests/unit/config/config_test.py b/tests/unit/config/config_test.py index 00ba6c2c6..fc28f8ef3 100644 --- a/tests/unit/config/config_test.py +++ b/tests/unit/config/config_test.py @@ -1678,6 +1678,25 @@ class ConfigTest(unittest.TestCase): } ] + def test_runtime_option(self): + actual = config.load(build_config_details({ + 'version': str(V2_3), + 'services': { + 'web': { + 'image': 'nvidia/cuda', + 'runtime': 'nvidia' + } + } + })) + + assert actual.services == [ + { + 'name': 'web', + 'image': 'nvidia/cuda', + 'runtime': 'nvidia', + } + ] + def test_merge_service_dicts_from_files_with_extends_in_base(self): base = { 'volumes': ['.:/app'],