From 854c003359bd07d0d3ca137d7a08509cfeab0436 Mon Sep 17 00:00:00 2001 From: aiordache Date: Fri, 23 Oct 2020 11:45:40 +0200 Subject: [PATCH] Implement device requests for GPU support Signed-off-by: aiordache --- compose/config/compose_spec.json | 20 +++++++++++++++++++- compose/project.py | 28 +++++++++++++++++++++++++++- compose/service.py | 4 ++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/compose/config/compose_spec.json b/compose/config/compose_spec.json index 268256744..0ecb3c69e 100644 --- a/compose/config/compose_spec.json +++ b/compose/config/compose_spec.json @@ -524,7 +524,8 @@ "properties": { "cpus": {"type": ["number", "string"]}, "memory": {"type": "string"}, - "generic_resources": {"$ref": "#/definitions/generic_resources"} + "generic_resources": {"$ref": "#/definitions/generic_resources"}, + "devices": {"$ref": "#/definitions/devices"} }, "additionalProperties": false, "patternProperties": {"^x-": {}} @@ -590,6 +591,23 @@ } }, + "devices": { + "id": "#/definitions/devices", + "type": "array", + "items": { + "type": "object", + "properties": { + "capabilities": {"$ref": "#/definitions/list_of_strings"}, + "count": {"type": ["string", "integer"]}, + "device_ids": {"$ref": "#/definitions/list_of_strings"}, + "driver":{"type": "string"}, + "options":{"$ref": "#/definitions/list_or_dict"} + }, + "additionalProperties": false, + "patternProperties": {"^x-": {}} + } + }, + "network": { "id": "#/definitions/network", "type": ["object", "null"], diff --git a/compose/project.py b/compose/project.py index 420cb6548..900487d4f 100644 --- a/compose/project.py +++ b/compose/project.py @@ -128,7 +128,7 @@ class Project: config_data.secrets) service_dict['scale'] = project.get_service_scale(service_dict) - + device_requests = project.get_device_requests(service_dict) service_dict = translate_credential_spec_to_security_opt(service_dict) service_dict, ignored_keys = translate_deploy_keys_to_container_config( service_dict @@ -154,6 +154,7 @@ class Project: ipc_mode=ipc_mode, platform=service_dict.pop('platform', None), default_platform=default_platform, + device_requests=device_requests, extra_labels=extra_labels, **service_dict) ) @@ -331,6 +332,31 @@ class Project: max_replicas)) return scale + def get_device_requests(self, service_dict): + deploy_dict = service_dict.get('deploy', None) + if not deploy_dict: + return + + resources = deploy_dict.get('resources', None) + if not resources or not resources.get('reservations', None): + return + devices = resources['reservations'].get('devices') + if not devices: + return + + for dev in devices: + count = dev.get("count", -1) + if not isinstance(count, int): + if count != "all": + raise ConfigurationError( + 'Invalid value "{}" for devices count'.format(dev["count"]), + '(expected integer or "all")') + dev["count"] = -1 + + if 'capabilities' in dev: + dev['capabilities'] = [dev['capabilities']] + return devices + def start(self, service_names=None, **options): containers = [] diff --git a/compose/service.py b/compose/service.py index a1a500cb2..e00a537cf 100644 --- a/compose/service.py +++ b/compose/service.py @@ -77,6 +77,7 @@ HOST_CONFIG_KEYS = [ 'cpuset', 'device_cgroup_rules', 'devices', + 'device_requests', 'dns', 'dns_search', 'dns_opt', @@ -180,6 +181,7 @@ class Service: pid_mode=None, default_platform=None, extra_labels=None, + device_requests=None, **options ): self.name = name @@ -195,6 +197,7 @@ class Service: self.secrets = secrets or [] self.scale_num = scale self.default_platform = default_platform + self.device_requests = device_requests self.options = options self.extra_labels = extra_labels or [] @@ -1016,6 +1019,7 @@ class Service: privileged=options.get('privileged', False), network_mode=self.network_mode.mode, devices=options.get('devices'), + device_requests=self.device_requests, dns=options.get('dns'), dns_opt=options.get('dns_opt'), dns_search=options.get('dns_search'),