From 22d90d21800bd5bf5c695f09cd3c98928781db9e Mon Sep 17 00:00:00 2001
From: Kevin Greene <kevin@spantree.net>
Date: Mon, 26 Oct 2015 17:39:50 -0400
Subject: [PATCH] Added ulimits functionality to docker compose

Signed-off-by: Kevin Greene <kevin@spantree.net>
---
 compose/config/config.py          | 12 +++++++
 compose/config/fields_schema.json | 19 ++++++++++
 compose/service.py                | 19 ++++++++++
 docs/compose-file.md              | 11 ++++++
 tests/integration/service_test.py | 31 ++++++++++++++++
 tests/unit/config/config_test.py  | 60 +++++++++++++++++++++++++++++++
 6 files changed, 152 insertions(+)

diff --git a/compose/config/config.py b/compose/config/config.py
index 434589d31..7931608d2 100644
--- a/compose/config/config.py
+++ b/compose/config/config.py
@@ -345,6 +345,15 @@ def validate_extended_service_dict(service_dict, filename, service):
                 "%s services with 'net: container' cannot be extended" % error_prefix)
 
 
+def validate_ulimits(ulimit_config):
+    for limit_name, soft_hard_values in six.iteritems(ulimit_config):
+        if isinstance(soft_hard_values, dict):
+            if not soft_hard_values['soft'] <= soft_hard_values['hard']:
+                raise ConfigurationError(
+                    "ulimit_config \"{}\" cannot contain a 'soft' value higher "
+                    "than 'hard' value".format(ulimit_config))
+
+
 def process_container_options(working_dir, service_dict):
     service_dict = dict(service_dict)
 
@@ -357,6 +366,9 @@ def process_container_options(working_dir, service_dict):
     if 'labels' in service_dict:
         service_dict['labels'] = parse_labels(service_dict['labels'])
 
+    if 'ulimits' in service_dict:
+        validate_ulimits(service_dict['ulimits'])
+
     return service_dict
 
 
diff --git a/compose/config/fields_schema.json b/compose/config/fields_schema.json
index e254e3539..f22b513ae 100644
--- a/compose/config/fields_schema.json
+++ b/compose/config/fields_schema.json
@@ -116,6 +116,25 @@
         "security_opt": {"type": "array", "items": {"type": "string"}, "uniqueItems": true},
         "stdin_open": {"type": "boolean"},
         "tty": {"type": "boolean"},
+        "ulimits": {
+          "type": "object",
+          "patternProperties": {
+            "^[a-z]+$": {
+              "oneOf": [
+                {"type": "integer"},
+                {
+                  "type":"object",
+                  "properties": {
+                    "hard": {"type": "integer"},
+                    "soft": {"type": "integer"}
+                  },
+                  "required": ["soft", "hard"],
+                  "additionalProperties": false
+                }
+              ]
+            }
+          }
+        },
         "user": {"type": "string"},
         "volumes": {"type": "array", "items": {"type": "string"}, "uniqueItems": true},
         "volume_driver": {"type": "string"},
diff --git a/compose/service.py b/compose/service.py
index 2055a6fe1..9e0066b77 100644
--- a/compose/service.py
+++ b/compose/service.py
@@ -676,6 +676,7 @@ class Service(object):
 
         devices = options.get('devices', None)
         cgroup_parent = options.get('cgroup_parent', None)
+        ulimits = build_ulimits(options.get('ulimits', None))
 
         return self.client.create_host_config(
             links=self._get_links(link_to_self=one_off),
@@ -692,6 +693,7 @@ class Service(object):
             cap_drop=cap_drop,
             mem_limit=options.get('mem_limit'),
             memswap_limit=options.get('memswap_limit'),
+            ulimits=ulimits,
             log_config=log_config,
             extra_hosts=extra_hosts,
             read_only=read_only,
@@ -1073,6 +1075,23 @@ def parse_restart_spec(restart_config):
 
     return {'Name': name, 'MaximumRetryCount': int(max_retry_count)}
 
+# Ulimits
+
+
+def build_ulimits(ulimit_config):
+    if not ulimit_config:
+        return None
+    ulimits = []
+    for limit_name, soft_hard_values in six.iteritems(ulimit_config):
+        if isinstance(soft_hard_values, six.integer_types):
+            ulimits.append({'name': limit_name, 'soft': soft_hard_values, 'hard': soft_hard_values})
+        elif isinstance(soft_hard_values, dict):
+            ulimit_dict = {'name': limit_name}
+            ulimit_dict.update(soft_hard_values)
+            ulimits.append(ulimit_dict)
+
+    return ulimits
+
 
 # Extra hosts
 
diff --git a/docs/compose-file.md b/docs/compose-file.md
index 4f8fc9e01..3b36fa2bd 100644
--- a/docs/compose-file.md
+++ b/docs/compose-file.md
@@ -333,6 +333,17 @@ Override the default labeling scheme for each container.
         - label:user:USER
         - label:role:ROLE
 
+### ulimits
+
+Override the default ulimits for a container. You can either use a number
+to set the hard and soft limits, or specify them in a dictionary.
+
+      ulimits:
+        nproc: 65535
+        nofile:
+          soft: 20000
+          hard: 40000
+
 ### volumes, volume\_driver
 
 Mount paths as volumes, optionally specifying a path on the host machine
diff --git a/tests/integration/service_test.py b/tests/integration/service_test.py
index 804f5219a..2f3be89a3 100644
--- a/tests/integration/service_test.py
+++ b/tests/integration/service_test.py
@@ -22,6 +22,7 @@ from compose.const import LABEL_SERVICE
 from compose.const import LABEL_VERSION
 from compose.container import Container
 from compose.service import build_extra_hosts
+from compose.service import build_ulimits
 from compose.service import ConfigError
 from compose.service import ConvergencePlan
 from compose.service import ConvergenceStrategy
@@ -164,6 +165,36 @@ class ServiceTest(DockerClientTestCase):
             {'www.example.com': '192.168.0.17',
              'api.example.com': '192.168.0.18'})
 
+    def sort_dicts_by_name(self, dictionary_list):
+        return sorted(dictionary_list, key=lambda k: k['name'])
+
+    def test_build_ulimits_with_invalid_options(self):
+        self.assertRaises(ConfigError, lambda: build_ulimits({'nofile': {'soft': 10000, 'hard': 10}}))
+
+    def test_build_ulimits_with_integers(self):
+        self.assertEqual(build_ulimits(
+            {'nofile': {'soft': 10000, 'hard': 20000}}),
+            [{'name': 'nofile', 'soft': 10000, 'hard': 20000}])
+        self.assertEqual(self.sort_dicts_by_name(build_ulimits(
+            {'nofile': {'soft': 10000, 'hard': 20000}, 'nproc': {'soft': 65535, 'hard': 65535}})),
+            self.sort_dicts_by_name([{'name': 'nofile', 'soft': 10000, 'hard': 20000},
+                                     {'name': 'nproc', 'soft': 65535, 'hard': 65535}]))
+
+    def test_build_ulimits_with_dicts(self):
+        self.assertEqual(build_ulimits(
+            {'nofile': 20000}),
+            [{'name': 'nofile', 'soft': 20000, 'hard': 20000}])
+        self.assertEqual(self.sort_dicts_by_name(build_ulimits(
+            {'nofile': 20000, 'nproc': 65535})),
+            self.sort_dicts_by_name([{'name': 'nofile', 'soft': 20000, 'hard': 20000},
+                                     {'name': 'nproc', 'soft': 65535, 'hard': 65535}]))
+
+    def test_build_ulimits_with_integers_and_dicts(self):
+        self.assertEqual(self.sort_dicts_by_name(build_ulimits(
+            {'nproc': 65535, 'nofile': {'soft': 10000, 'hard': 20000}})),
+            self.sort_dicts_by_name([{'name': 'nofile', 'soft': 10000, 'hard': 20000},
+                                     {'name': 'nproc', 'soft': 65535, 'hard': 65535}]))
+
     def test_create_container_with_extra_hosts_list(self):
         extra_hosts = ['somehost:162.242.195.82', 'otherhost:50.31.209.229']
         service = self.create_service('db', extra_hosts=extra_hosts)
diff --git a/tests/unit/config/config_test.py b/tests/unit/config/config_test.py
index e0d2e870b..f27329ba0 100644
--- a/tests/unit/config/config_test.py
+++ b/tests/unit/config/config_test.py
@@ -349,6 +349,66 @@ class ConfigTest(unittest.TestCase):
                 )
             )
 
+    def test_config_ulimits_invalid_keys_validation_error(self):
+        expected_error_msg = "Service 'web' configuration key 'ulimits' contains unsupported option: 'not_soft_or_hard'"
+
+        with self.assertRaisesRegexp(ConfigurationError, expected_error_msg):
+            config.load(
+                build_config_details(
+                    {'web': {
+                        'image': 'busybox',
+                        'ulimits': {
+                            'nofile': {
+                                "not_soft_or_hard": 100,
+                                "soft": 10000,
+                                "hard": 20000,
+                            }
+                        }
+                    }},
+                    'working_dir',
+                    'filename.yml'
+                )
+            )
+
+    def test_config_ulimits_required_keys_validation_error(self):
+        expected_error_msg = "Service 'web' configuration key 'ulimits' u?'hard' is a required property"
+
+        with self.assertRaisesRegexp(ConfigurationError, expected_error_msg):
+            config.load(
+                build_config_details(
+                    {'web': {
+                        'image': 'busybox',
+                        'ulimits': {
+                            'nofile': {
+                                "soft": 10000,
+                            }
+                        }
+                    }},
+                    'working_dir',
+                    'filename.yml'
+                )
+            )
+
+    def test_config_ulimits_soft_greater_than_hard_error(self):
+        expected_error_msg = "cannot contain a 'soft' value higher than 'hard' value"
+
+        with self.assertRaisesRegexp(ConfigurationError, expected_error_msg):
+            config.load(
+                build_config_details(
+                    {'web': {
+                        'image': 'busybox',
+                        'ulimits': {
+                            'nofile': {
+                                "soft": 10000,
+                                "hard": 1000
+                            }
+                        }
+                    }},
+                    'working_dir',
+                    'filename.yml'
+                )
+            )
+
     def test_valid_config_which_allows_two_type_definitions(self):
         expose_values = [["8000"], [8000]]
         for expose in expose_values: