Add progress messages to parallel pull

Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
Joffrey F 2018-07-20 15:37:15 -07:00
parent 6cb17b90ef
commit c956785cdc
6 changed files with 63 additions and 29 deletions

View File

@ -313,6 +313,13 @@ class ParallelStreamWriter(object):
self._write_ansi(msg, obj_index, color_func(status))
def get_stream_writer():
instance = ParallelStreamWriter.instance
if instance is None:
raise RuntimeError('ParallelStreamWriter has not yet been instantiated')
return instance
def parallel_operation(containers, operation, options, message):
parallel_execute(
containers,

View File

@ -19,12 +19,11 @@ def write_to_stream(s, stream):
def stream_output(output, stream):
is_terminal = hasattr(stream, 'isatty') and stream.isatty()
stream = utils.get_output_stream(stream)
all_events = []
lines = {}
diff = 0
for event in utils.json_stream(output):
all_events.append(event)
yield event
is_progress_event = 'progress' in event or 'progressDetail' in event
if not is_progress_event:
@ -57,8 +56,6 @@ def stream_output(output, stream):
stream.flush()
return all_events
def print_output_event(event, stream, is_terminal):
if 'errorDetail' in event:

View File

@ -548,16 +548,37 @@ class Project(object):
def pull(self, service_names=None, ignore_pull_failures=False, parallel_pull=False, silent=False,
include_deps=False):
services = self.get_services(service_names, include_deps)
msg = not silent and 'Pulling' or None
if parallel_pull:
def pull_service(service):
service.pull(ignore_pull_failures, True)
strm = service.pull(ignore_pull_failures, True, stream=True)
writer = parallel.get_stream_writer()
def trunc(s):
if len(s) > 35:
return s[:33] + '...'
return s
for event in strm:
if 'status' not in event:
continue
status = event['status'].lower()
if 'progressDetail' in event:
detail = event['progressDetail']
if 'current' in detail and 'total' in detail:
percentage = float(detail['current']) / float(detail['total'])
status = '{} ({:.1%})'.format(status, percentage)
writer.write(
msg, service.name, trunc(status), lambda s: s
)
_, errors = parallel.parallel_execute(
services,
pull_service,
operator.attrgetter('name'),
not silent and 'Pulling' or None,
msg,
limit=5,
)
if len(errors):

View File

@ -1068,7 +1068,7 @@ class Service(object):
)
try:
all_events = stream_output(build_output, sys.stdout)
all_events = list(stream_output(build_output, sys.stdout))
except StreamOutputError as e:
raise BuildError(self, six.text_type(e))
@ -1162,7 +1162,23 @@ class Service(object):
return any(has_host_port(binding) for binding in self.options.get('ports', []))
def pull(self, ignore_pull_failures=False, silent=False):
def _do_pull(self, repo, pull_kwargs, silent, ignore_pull_failures):
try:
output = self.client.pull(repo, **pull_kwargs)
if silent:
with open(os.devnull, 'w') as devnull:
for event in stream_output(output, devnull):
yield event
else:
for event in stream_output(output, sys.stdout):
yield event
except (StreamOutputError, NotFound) as e:
if not ignore_pull_failures:
raise
else:
log.error(six.text_type(e))
def pull(self, ignore_pull_failures=False, silent=False, stream=False):
if 'image' not in self.options:
return
@ -1179,20 +1195,11 @@ class Service(object):
raise OperationFailedError(
'Impossible to perform platform-targeted pulls for API version < 1.35'
)
try:
output = self.client.pull(repo, **kwargs)
if silent:
with open(os.devnull, 'w') as devnull:
return progress_stream.get_digest_from_pull(
stream_output(output, devnull))
else:
return progress_stream.get_digest_from_pull(
stream_output(output, sys.stdout))
except (StreamOutputError, NotFound) as e:
if not ignore_pull_failures:
raise
else:
log.error(six.text_type(e))
event_stream = self._do_pull(repo, kwargs, silent, ignore_pull_failures)
if stream:
return event_stream
return progress_stream.get_digest_from_pull(event_stream)
def push(self, ignore_push_failures=False):
if 'image' not in self.options or 'build' not in self.options:

View File

@ -139,7 +139,9 @@ class DockerClientTestCase(unittest.TestCase):
def check_build(self, *args, **kwargs):
kwargs.setdefault('rm', True)
build_output = self.client.build(*args, **kwargs)
stream_output(build_output, open('/dev/null', 'w'))
with open(os.devnull, 'w') as devnull:
for event in stream_output(build_output, devnull):
pass
def require_api_version(self, minimum):
api_version = self.client.version()['ApiVersion']

View File

@ -21,7 +21,7 @@ class ProgressStreamTestCase(unittest.TestCase):
b'31019763, "start": 1413653874, "total": 62763875}, '
b'"progress": "..."}',
]
events = progress_stream.stream_output(output, StringIO())
events = list(progress_stream.stream_output(output, StringIO()))
assert len(events) == 1
def test_stream_output_div_zero(self):
@ -30,7 +30,7 @@ class ProgressStreamTestCase(unittest.TestCase):
b'0, "start": 1413653874, "total": 0}, '
b'"progress": "..."}',
]
events = progress_stream.stream_output(output, StringIO())
events = list(progress_stream.stream_output(output, StringIO()))
assert len(events) == 1
def test_stream_output_null_total(self):
@ -39,7 +39,7 @@ class ProgressStreamTestCase(unittest.TestCase):
b'0, "start": 1413653874, "total": null}, '
b'"progress": "..."}',
]
events = progress_stream.stream_output(output, StringIO())
events = list(progress_stream.stream_output(output, StringIO()))
assert len(events) == 1
def test_stream_output_progress_event_tty(self):
@ -52,7 +52,7 @@ class ProgressStreamTestCase(unittest.TestCase):
return True
output = TTYStringIO()
events = progress_stream.stream_output(events, output)
events = list(progress_stream.stream_output(events, output))
assert len(output.getvalue()) > 0
def test_stream_output_progress_event_no_tty(self):
@ -61,7 +61,7 @@ class ProgressStreamTestCase(unittest.TestCase):
]
output = StringIO()
events = progress_stream.stream_output(events, output)
events = list(progress_stream.stream_output(events, output))
assert len(output.getvalue()) == 0
def test_stream_output_no_progress_event_no_tty(self):
@ -70,7 +70,7 @@ class ProgressStreamTestCase(unittest.TestCase):
]
output = StringIO()
events = progress_stream.stream_output(events, output)
events = list(progress_stream.stream_output(events, output))
assert len(output.getvalue()) > 0
def test_mismatched_encoding_stream_write(self):