From c956785cdc242f081043f3879047c428f0bd893a Mon Sep 17 00:00:00 2001 From: Joffrey F Date: Fri, 20 Jul 2018 15:37:15 -0700 Subject: [PATCH] Add progress messages to parallel pull Signed-off-by: Joffrey F --- compose/parallel.py | 7 ++++++ compose/progress_stream.py | 5 +--- compose/project.py | 25 +++++++++++++++++-- compose/service.py | 39 ++++++++++++++++++------------ tests/integration/testcases.py | 4 ++- tests/unit/progress_stream_test.py | 12 ++++----- 6 files changed, 63 insertions(+), 29 deletions(-) diff --git a/compose/parallel.py b/compose/parallel.py index a2eb160e5..34a498ca7 100644 --- a/compose/parallel.py +++ b/compose/parallel.py @@ -313,6 +313,13 @@ class ParallelStreamWriter(object): self._write_ansi(msg, obj_index, color_func(status)) +def get_stream_writer(): + instance = ParallelStreamWriter.instance + if instance is None: + raise RuntimeError('ParallelStreamWriter has not yet been instantiated') + return instance + + def parallel_operation(containers, operation, options, message): parallel_execute( containers, diff --git a/compose/progress_stream.py b/compose/progress_stream.py index 5e709770a..4cd311432 100644 --- a/compose/progress_stream.py +++ b/compose/progress_stream.py @@ -19,12 +19,11 @@ def write_to_stream(s, stream): def stream_output(output, stream): is_terminal = hasattr(stream, 'isatty') and stream.isatty() stream = utils.get_output_stream(stream) - all_events = [] lines = {} diff = 0 for event in utils.json_stream(output): - all_events.append(event) + yield event is_progress_event = 'progress' in event or 'progressDetail' in event if not is_progress_event: @@ -57,8 +56,6 @@ def stream_output(output, stream): stream.flush() - return all_events - def print_output_event(event, stream, is_terminal): if 'errorDetail' in event: diff --git a/compose/project.py b/compose/project.py index 005b7e240..391bbd038 100644 --- a/compose/project.py +++ b/compose/project.py @@ -548,16 +548,37 @@ class Project(object): def pull(self, service_names=None, ignore_pull_failures=False, parallel_pull=False, silent=False, include_deps=False): services = self.get_services(service_names, include_deps) + msg = not silent and 'Pulling' or None if parallel_pull: def pull_service(service): - service.pull(ignore_pull_failures, True) + strm = service.pull(ignore_pull_failures, True, stream=True) + writer = parallel.get_stream_writer() + + def trunc(s): + if len(s) > 35: + return s[:33] + '...' + return s + + for event in strm: + if 'status' not in event: + continue + status = event['status'].lower() + if 'progressDetail' in event: + detail = event['progressDetail'] + if 'current' in detail and 'total' in detail: + percentage = float(detail['current']) / float(detail['total']) + status = '{} ({:.1%})'.format(status, percentage) + + writer.write( + msg, service.name, trunc(status), lambda s: s + ) _, errors = parallel.parallel_execute( services, pull_service, operator.attrgetter('name'), - not silent and 'Pulling' or None, + msg, limit=5, ) if len(errors): diff --git a/compose/service.py b/compose/service.py index e77780fd8..4b545ab02 100644 --- a/compose/service.py +++ b/compose/service.py @@ -1068,7 +1068,7 @@ class Service(object): ) try: - all_events = stream_output(build_output, sys.stdout) + all_events = list(stream_output(build_output, sys.stdout)) except StreamOutputError as e: raise BuildError(self, six.text_type(e)) @@ -1162,7 +1162,23 @@ class Service(object): return any(has_host_port(binding) for binding in self.options.get('ports', [])) - def pull(self, ignore_pull_failures=False, silent=False): + def _do_pull(self, repo, pull_kwargs, silent, ignore_pull_failures): + try: + output = self.client.pull(repo, **pull_kwargs) + if silent: + with open(os.devnull, 'w') as devnull: + for event in stream_output(output, devnull): + yield event + else: + for event in stream_output(output, sys.stdout): + yield event + except (StreamOutputError, NotFound) as e: + if not ignore_pull_failures: + raise + else: + log.error(six.text_type(e)) + + def pull(self, ignore_pull_failures=False, silent=False, stream=False): if 'image' not in self.options: return @@ -1179,20 +1195,11 @@ class Service(object): raise OperationFailedError( 'Impossible to perform platform-targeted pulls for API version < 1.35' ) - try: - output = self.client.pull(repo, **kwargs) - if silent: - with open(os.devnull, 'w') as devnull: - return progress_stream.get_digest_from_pull( - stream_output(output, devnull)) - else: - return progress_stream.get_digest_from_pull( - stream_output(output, sys.stdout)) - except (StreamOutputError, NotFound) as e: - if not ignore_pull_failures: - raise - else: - log.error(six.text_type(e)) + + event_stream = self._do_pull(repo, kwargs, silent, ignore_pull_failures) + if stream: + return event_stream + return progress_stream.get_digest_from_pull(event_stream) def push(self, ignore_push_failures=False): if 'image' not in self.options or 'build' not in self.options: diff --git a/tests/integration/testcases.py b/tests/integration/testcases.py index 4440d771e..cfdf22f7e 100644 --- a/tests/integration/testcases.py +++ b/tests/integration/testcases.py @@ -139,7 +139,9 @@ class DockerClientTestCase(unittest.TestCase): def check_build(self, *args, **kwargs): kwargs.setdefault('rm', True) build_output = self.client.build(*args, **kwargs) - stream_output(build_output, open('/dev/null', 'w')) + with open(os.devnull, 'w') as devnull: + for event in stream_output(build_output, devnull): + pass def require_api_version(self, minimum): api_version = self.client.version()['ApiVersion'] diff --git a/tests/unit/progress_stream_test.py b/tests/unit/progress_stream_test.py index f4a0ab063..d29227458 100644 --- a/tests/unit/progress_stream_test.py +++ b/tests/unit/progress_stream_test.py @@ -21,7 +21,7 @@ class ProgressStreamTestCase(unittest.TestCase): b'31019763, "start": 1413653874, "total": 62763875}, ' b'"progress": "..."}', ] - events = progress_stream.stream_output(output, StringIO()) + events = list(progress_stream.stream_output(output, StringIO())) assert len(events) == 1 def test_stream_output_div_zero(self): @@ -30,7 +30,7 @@ class ProgressStreamTestCase(unittest.TestCase): b'0, "start": 1413653874, "total": 0}, ' b'"progress": "..."}', ] - events = progress_stream.stream_output(output, StringIO()) + events = list(progress_stream.stream_output(output, StringIO())) assert len(events) == 1 def test_stream_output_null_total(self): @@ -39,7 +39,7 @@ class ProgressStreamTestCase(unittest.TestCase): b'0, "start": 1413653874, "total": null}, ' b'"progress": "..."}', ] - events = progress_stream.stream_output(output, StringIO()) + events = list(progress_stream.stream_output(output, StringIO())) assert len(events) == 1 def test_stream_output_progress_event_tty(self): @@ -52,7 +52,7 @@ class ProgressStreamTestCase(unittest.TestCase): return True output = TTYStringIO() - events = progress_stream.stream_output(events, output) + events = list(progress_stream.stream_output(events, output)) assert len(output.getvalue()) > 0 def test_stream_output_progress_event_no_tty(self): @@ -61,7 +61,7 @@ class ProgressStreamTestCase(unittest.TestCase): ] output = StringIO() - events = progress_stream.stream_output(events, output) + events = list(progress_stream.stream_output(events, output)) assert len(output.getvalue()) == 0 def test_stream_output_no_progress_event_no_tty(self): @@ -70,7 +70,7 @@ class ProgressStreamTestCase(unittest.TestCase): ] output = StringIO() - events = progress_stream.stream_output(events, output) + events = list(progress_stream.stream_output(events, output)) assert len(output.getvalue()) > 0 def test_mismatched_encoding_stream_write(self):