From 7210fdb21cf1875161005938cd1ba4c33dbc0f2e Mon Sep 17 00:00:00 2001
From: Joffrey F <joffrey@docker.com>
Date: Fri, 11 Aug 2017 15:52:43 -0700
Subject: [PATCH] Add support for start_period in healthcheck config

Improve merging strategy for healthcheck configs

Signed-off-by: Joffrey F <joffrey@docker.com>
---
 compose/config/config.py               | 25 +++++----
 compose/config/config_schema_v2.3.json |  1 +
 compose/config/serialize.py            |  4 ++
 compose/utils.py                       |  5 +-
 tests/integration/service_test.py      | 19 +++++++
 tests/unit/config/config_test.py       | 77 +++++++++++++++++++++++++-
 6 files changed, 117 insertions(+), 14 deletions(-)

diff --git a/compose/config/config.py b/compose/config/config.py
index aa829a40e..0c2ab1ab7 100644
--- a/compose/config/config.py
+++ b/compose/config/config.py
@@ -797,16 +797,12 @@ def process_healthcheck(service_dict, service_name):
     elif 'test' in raw:
         hc['test'] = raw['test']
 
-    if 'interval' in raw:
-        if not isinstance(raw['interval'], six.integer_types):
-            hc['interval'] = parse_nanoseconds_int(raw['interval'])
-        else:  # Conversion has been done previously
-            hc['interval'] = raw['interval']
-    if 'timeout' in raw:
-        if not isinstance(raw['timeout'], six.integer_types):
-            hc['timeout'] = parse_nanoseconds_int(raw['timeout'])
-        else:  # Conversion has been done previously
-            hc['timeout'] = raw['timeout']
+    for field in ['interval', 'timeout', 'start_period']:
+        if field in raw:
+            if not isinstance(raw[field], six.integer_types):
+                hc[field] = parse_nanoseconds_int(raw[field])
+            else:  # Conversion has been done previously
+                hc[field] = raw[field]
     if 'retries' in raw:
         hc['retries'] = raw['retries']
 
@@ -967,6 +963,7 @@ def merge_service_dicts(base, override, version):
     md.merge_field('logging', merge_logging, default={})
     merge_ports(md, base, override)
     md.merge_field('blkio_config', merge_blkio_config, default={})
+    md.merge_field('healthcheck', merge_healthchecks, default={})
 
     for field in set(ALLOWED_KEYS) - set(md):
         md.merge_scalar(field)
@@ -985,6 +982,14 @@ def merge_unique_items_lists(base, override):
     return sorted(set().union(base, override))
 
 
+def merge_healthchecks(base, override):
+    if override.get('disabled') is True:
+        return override
+    result = base.copy()
+    result.update(override)
+    return result
+
+
 def merge_ports(md, base, override):
     def parse_sequence_func(seq):
         acc = []
diff --git a/compose/config/config_schema_v2.3.json b/compose/config/config_schema_v2.3.json
index 789adf4ab..7a9bdfdf1 100644
--- a/compose/config/config_schema_v2.3.json
+++ b/compose/config/config_schema_v2.3.json
@@ -309,6 +309,7 @@
         "disable": {"type": "boolean"},
         "interval": {"type": "string"},
         "retries": {"type": "number"},
+        "start_period": {"type": "string"},
         "test": {
           "oneOf": [
             {"type": "string"},
diff --git a/compose/config/serialize.py b/compose/config/serialize.py
index 0f0cb7f50..daddff695 100644
--- a/compose/config/serialize.py
+++ b/compose/config/serialize.py
@@ -131,6 +131,10 @@ def denormalize_service_dict(service_dict, version, image_digest=None):
                 service_dict['healthcheck']['timeout']
             )
 
+        if 'start_period' in service_dict['healthcheck']:
+            service_dict['healthcheck']['start_period'] = serialize_ns_time_value(
+                service_dict['healthcheck']['start_period']
+            )
     if 'ports' in service_dict and version < V3_2:
         service_dict['ports'] = [
             p.legacy_repr() if isinstance(p, types.ServicePort) else p
diff --git a/compose/utils.py b/compose/utils.py
index 183a4504d..1ede4d37d 100644
--- a/compose/utils.py
+++ b/compose/utils.py
@@ -14,6 +14,7 @@ from docker.utils import parse_bytes as sdk_parse_bytes
 
 from .config.errors import ConfigurationError
 from .errors import StreamParseError
+from .timeparse import MULTIPLIERS
 from .timeparse import timeparse
 
 
@@ -112,7 +113,7 @@ def microseconds_from_time_nano(time_nano):
 
 
 def nanoseconds_from_time_seconds(time_seconds):
-    return time_seconds * 1000000000
+    return int(time_seconds / MULTIPLIERS['nano'])
 
 
 def parse_seconds_float(value):
@@ -123,7 +124,7 @@ def parse_nanoseconds_int(value):
     parsed = timeparse(value or '')
     if parsed is None:
         return None
-    return int(parsed * 1000000000)
+    return nanoseconds_from_time_seconds(parsed)
 
 
 def build_string_dict(source_dict):
diff --git a/tests/integration/service_test.py b/tests/integration/service_test.py
index 2abb12c34..84b54fe41 100644
--- a/tests/integration/service_test.py
+++ b/tests/integration/service_test.py
@@ -36,6 +36,7 @@ from compose.service import ConvergenceStrategy
 from compose.service import NetworkMode
 from compose.service import PidMode
 from compose.service import Service
+from compose.utils import parse_nanoseconds_int
 from tests.integration.testcases import is_cluster
 from tests.integration.testcases import no_cluster
 from tests.integration.testcases import v2_1_only
@@ -270,6 +271,24 @@ class ServiceTest(DockerClientTestCase):
         self.assertTrue(path.basename(actual_host_path) == path.basename(host_path),
                         msg=("Last component differs: %s, %s" % (actual_host_path, host_path)))
 
+    def test_create_container_with_healthcheck_config(self):
+        one_second = parse_nanoseconds_int('1s')
+        healthcheck = {
+            'test': ['true'],
+            'interval': 2 * one_second,
+            'timeout': 5 * one_second,
+            'retries': 5,
+            'start_period': 2 * one_second
+        }
+        service = self.create_service('db', healthcheck=healthcheck)
+        container = service.create_container()
+        remote_healthcheck = container.get('Config.Healthcheck')
+        assert remote_healthcheck['Test'] == healthcheck['test']
+        assert remote_healthcheck['Interval'] == healthcheck['interval']
+        assert remote_healthcheck['Timeout'] == healthcheck['timeout']
+        assert remote_healthcheck['Retries'] == healthcheck['retries']
+        assert remote_healthcheck['StartPeriod'] == healthcheck['start_period']
+
     def test_recreate_preserves_volume_with_trailing_slash(self):
         """When the Compose file specifies a trailing slash in the container path, make
         sure we copy the volume over when recreating.
diff --git a/tests/unit/config/config_test.py b/tests/unit/config/config_test.py
index 8a1e16f8a..4e355d3bf 100644
--- a/tests/unit/config/config_test.py
+++ b/tests/unit/config/config_test.py
@@ -2197,6 +2197,75 @@ class ConfigTest(unittest.TestCase):
             }
         }
 
+    def test_merge_healthcheck_config(self):
+        base = {
+            'image': 'bar',
+            'healthcheck': {
+                'start_period': 1000,
+                'interval': 3000,
+                'test': ['true']
+            }
+        }
+
+        override = {
+            'healthcheck': {
+                'interval': 5000,
+                'timeout': 10000,
+                'test': ['echo', 'OK'],
+            }
+        }
+
+        actual = config.merge_service_dicts(base, override, V2_3)
+        assert actual['healthcheck'] == {
+            'start_period': base['healthcheck']['start_period'],
+            'test': override['healthcheck']['test'],
+            'interval': override['healthcheck']['interval'],
+            'timeout': override['healthcheck']['timeout'],
+        }
+
+    def test_merge_healthcheck_override_disables(self):
+        base = {
+            'image': 'bar',
+            'healthcheck': {
+                'start_period': 1000,
+                'interval': 3000,
+                'timeout': 2000,
+                'retries': 3,
+                'test': ['true']
+            }
+        }
+
+        override = {
+            'healthcheck': {
+                'disabled': True
+            }
+        }
+
+        actual = config.merge_service_dicts(base, override, V2_3)
+        assert actual['healthcheck'] == {'disabled': True}
+
+    def test_merge_healthcheck_override_enables(self):
+        base = {
+            'image': 'bar',
+            'healthcheck': {
+                'disabled': True
+            }
+        }
+
+        override = {
+            'healthcheck': {
+                'disabled': False,
+                'start_period': 1000,
+                'interval': 3000,
+                'timeout': 2000,
+                'retries': 3,
+                'test': ['true']
+            }
+        }
+
+        actual = config.merge_service_dicts(base, override, V2_3)
+        assert actual['healthcheck'] == override['healthcheck']
+
     def test_external_volume_config(self):
         config_details = build_config_details({
             'version': '2',
@@ -4008,6 +4077,7 @@ class HealthcheckTest(unittest.TestCase):
                 'interval': '1s',
                 'timeout': '1m',
                 'retries': 3,
+                'start_period': '10s'
             }},
             '.',
         )
@@ -4017,6 +4087,7 @@ class HealthcheckTest(unittest.TestCase):
             'interval': nanoseconds_from_time_seconds(1),
             'timeout': nanoseconds_from_time_seconds(60),
             'retries': 3,
+            'start_period': nanoseconds_from_time_seconds(10)
         }
 
     def test_disable(self):
@@ -4147,15 +4218,17 @@ class SerializeTest(unittest.TestCase):
                 'test': 'exit 1',
                 'interval': '1m40s',
                 'timeout': '30s',
-                'retries': 5
+                'retries': 5,
+                'start_period': '2s90ms'
             }
         }
         processed_service = config.process_service(config.ServiceConfig(
             '.', 'test', 'test', service_dict
         ))
-        denormalized_service = denormalize_service_dict(processed_service, V2_1)
+        denormalized_service = denormalize_service_dict(processed_service, V2_3)
         assert denormalized_service['healthcheck']['interval'] == '100s'
         assert denormalized_service['healthcheck']['timeout'] == '30s'
+        assert denormalized_service['healthcheck']['start_period'] == '2090ms'
 
     def test_denormalize_image_has_digest(self):
         service_dict = {