diff --git a/compose/config/config.py b/compose/config/config.py index 1e793d9f6..6cffa2fe8 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -4,7 +4,7 @@ import sys import yaml from collections import namedtuple import json -import jsonschema +from jsonschema import Draft4Validator, FormatChecker, ValidationError import six @@ -133,6 +133,28 @@ def get_config_path(base_dir): return os.path.join(path, winner) +@FormatChecker.cls_checks(format="ports", raises=ValidationError("Ports is incorrectly formatted.")) +def format_ports(instance): + def _is_valid(port): + if ':' in port or '/' in port: + return True + try: + int(port) + return True + except ValueError: + return False + return False + + if isinstance(instance, list): + for port in instance: + if not _is_valid(port): + return False + return True + elif isinstance(instance, str): + return _is_valid(instance) + return False + + def validate_against_schema(config): config_source_dir = os.path.dirname(os.path.abspath(__file__)) schema_file = os.path.join(config_source_dir, "schema.json") @@ -140,7 +162,7 @@ def validate_against_schema(config): with open(schema_file, "r") as schema_fh: schema = json.load(schema_fh) - validation_output = jsonschema.Draft4Validator(schema) + validation_output = Draft4Validator(schema, format_checker=FormatChecker(["ports"])) errors = [error.message for error in sorted(validation_output.iter_errors(config), key=str)] if errors: diff --git a/compose/schema.json b/compose/schema.json index 7c7e2d096..bf43ca36b 100644 --- a/compose/schema.json +++ b/compose/schema.json @@ -14,6 +14,17 @@ "type": "object", "properties": { + "ports": { + "oneOf": [ + {"type": "string", "format": "ports"}, + { + "type": "array", + "items": {"type": "string"}, + "uniqueItems": true, + "format": "ports" + } + ] + }, "build": {"type": "string"}, "env_file": {"$ref": "#/definitions/string_or_list"}, "environment": { diff --git a/tests/unit/config_test.py b/tests/unit/config_test.py index f06cbab63..f7e949d3c 100644 --- a/tests/unit/config_test.py +++ b/tests/unit/config_test.py @@ -80,6 +80,28 @@ class ConfigTest(unittest.TestCase): ) ) + def test_config_invalid_ports_format_validation(self): + with self.assertRaises(config.ConfigurationError): + for invalid_ports in [{"1": "8000"}, "whatport"]: + config.load( + config.ConfigDetails( + {'web': {'image': 'busybox', 'ports': invalid_ports}}, + 'working_dir', + 'filename.yml' + ) + ) + + def test_config_valid_ports_format_validation(self): + valid_ports = [["8000", "9000"], "625", "8000:8050", ["8000/8050"]] + for ports in valid_ports: + config.load( + config.ConfigDetails( + {'web': {'image': 'busybox', 'ports': ports}}, + 'working_dir', + 'filename.yml' + ) + ) + class InterpolationTest(unittest.TestCase): @mock.patch.dict(os.environ)