upstream commit

switch sshconnect.c from (slightly abused) select() to
poll(); ok deraadt@ a while back

Upstream-ID: efc1937fc591bbe70ac9e9542bb984f354c8c175
This commit is contained in:
djm@openbsd.org 2017-06-24 05:37:44 +00:00 committed by Damien Miller
parent 6f8ca3b925
commit 4540428cd0
1 changed files with 65 additions and 93 deletions

View File

@ -1,4 +1,4 @@
/* $OpenBSD: sshconnect.c,v 1.281 2017/06/24 05:35:05 djm Exp $ */ /* $OpenBSD: sshconnect.c,v 1.282 2017/06/24 05:37:44 djm Exp $ */
/* /*
* Author: Tatu Ylonen <ylo@cs.hut.fi> * Author: Tatu Ylonen <ylo@cs.hut.fi>
* Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@ -34,6 +34,9 @@
#include <paths.h> #include <paths.h>
#endif #endif
#include <pwd.h> #include <pwd.h>
#ifdef HAVE_POLL_H
#include <poll.h>
#endif
#include <signal.h> #include <signal.h>
#include <stdarg.h> #include <stdarg.h>
#include <stdio.h> #include <stdio.h>
@ -328,87 +331,71 @@ ssh_create_socket(int privileged, struct addrinfo *ai)
return sock; return sock;
} }
/*
* Wait up to *timeoutp milliseconds for fd to be readable. Updates
* *timeoutp with time remaining.
* Returns 0 if fd ready or -1 on timeout or error (see errno).
*/
static int
waitrfd(int fd, int *timeoutp)
{
struct pollfd pfd;
struct timeval t_start;
int oerrno, r;
gettimeofday(&t_start, NULL);
pfd.fd = fd;
pfd.events = POLLIN;
for (; *timeoutp >= 0;) {
r = poll(&pfd, 1, *timeoutp);
oerrno = errno;
ms_subtract_diff(&t_start, timeoutp);
errno = oerrno;
if (r > 0)
return 0;
else if (r == -1 && errno != EAGAIN)
return -1;
else if (r == 0)
break;
}
/* timeout */
errno = ETIMEDOUT;
return -1;
}
static int static int
timeout_connect(int sockfd, const struct sockaddr *serv_addr, timeout_connect(int sockfd, const struct sockaddr *serv_addr,
socklen_t addrlen, int *timeoutp) socklen_t addrlen, int *timeoutp)
{ {
fd_set *fdset; int optval = 0;
struct timeval tv, t_start; socklen_t optlen = sizeof(optval);
socklen_t optlen;
int optval, rc, result = -1;
gettimeofday(&t_start, NULL); /* No timeout: just do a blocking connect() */
if (*timeoutp <= 0)
if (*timeoutp <= 0) { return connect(sockfd, serv_addr, addrlen);
result = connect(sockfd, serv_addr, addrlen);
goto done;
}
set_nonblock(sockfd); set_nonblock(sockfd);
rc = connect(sockfd, serv_addr, addrlen); if (connect(sockfd, serv_addr, addrlen) == 0) {
if (rc == 0) { /* Succeeded already? */
unset_nonblock(sockfd); unset_nonblock(sockfd);
result = 0; return 0;
goto done; } else if (errno != EINPROGRESS)
return -1;
if (waitrfd(sockfd, timeoutp) == -1)
return -1;
/* Completed or failed */
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, &optlen) == -1) {
debug("getsockopt: %s", strerror(errno));
return -1;
} }
if (errno != EINPROGRESS) { if (optval != 0) {
result = -1; errno = optval;
goto done; return -1;
} }
unset_nonblock(sockfd);
fdset = xcalloc(howmany(sockfd + 1, NFDBITS), return 0;
sizeof(fd_mask));
FD_SET(sockfd, fdset);
ms_to_timeval(&tv, *timeoutp);
for (;;) {
rc = select(sockfd + 1, NULL, fdset, NULL, &tv);
if (rc != -1 || errno != EINTR)
break;
}
switch (rc) {
case 0:
/* Timed out */
errno = ETIMEDOUT;
break;
case -1:
/* Select error */
debug("select: %s", strerror(errno));
break;
case 1:
/* Completed or failed */
optval = 0;
optlen = sizeof(optval);
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval,
&optlen) == -1) {
debug("getsockopt: %s", strerror(errno));
break;
}
if (optval != 0) {
errno = optval;
break;
}
result = 0;
unset_nonblock(sockfd);
break;
default:
/* Should not occur */
fatal("Bogus return (%d) from select()", rc);
}
free(fdset);
done:
if (result == 0 && *timeoutp > 0) {
ms_subtract_diff(&t_start, timeoutp);
if (*timeoutp <= 0) {
errno = ETIMEDOUT;
result = -1;
}
}
return (result);
} }
/* /*
@ -546,39 +533,25 @@ ssh_exchange_identification(int timeout_ms)
int connection_out = packet_get_connection_out(); int connection_out = packet_get_connection_out();
u_int i, n; u_int i, n;
size_t len; size_t len;
int fdsetsz, remaining, rc; int rc;
struct timeval t_start, t_remaining;
fd_set *fdset;
fdsetsz = howmany(connection_in + 1, NFDBITS) * sizeof(fd_mask);
fdset = xcalloc(1, fdsetsz);
send_client_banner(connection_out, 0); send_client_banner(connection_out, 0);
/* Read other side's version identification. */ /* Read other side's version identification. */
remaining = timeout_ms;
for (n = 0;;) { for (n = 0;;) {
for (i = 0; i < sizeof(buf) - 1; i++) { for (i = 0; i < sizeof(buf) - 1; i++) {
if (timeout_ms > 0) { if (timeout_ms > 0) {
gettimeofday(&t_start, NULL); rc = waitrfd(connection_in, &timeout_ms);
ms_to_timeval(&t_remaining, remaining); if (rc == -1 && errno == ETIMEDOUT) {
FD_SET(connection_in, fdset);
rc = select(connection_in + 1, fdset, NULL,
fdset, &t_remaining);
ms_subtract_diff(&t_start, &remaining);
if (rc == 0 || remaining <= 0)
fatal("Connection timed out during " fatal("Connection timed out during "
"banner exchange"); "banner exchange");
if (rc == -1) { } else if (rc == -1) {
if (errno == EINTR) fatal("%s: %s",
continue; __func__, strerror(errno));
fatal("ssh_exchange_identification: "
"select: %s", strerror(errno));
} }
} }
len = atomicio(read, connection_in, &buf[i], 1); len = atomicio(read, connection_in, &buf[i], 1);
if (len != 1 && errno == EPIPE) if (len != 1 && errno == EPIPE)
fatal("ssh_exchange_identification: " fatal("ssh_exchange_identification: "
"Connection closed by remote host"); "Connection closed by remote host");
@ -604,7 +577,6 @@ ssh_exchange_identification(int timeout_ms)
debug("ssh_exchange_identification: %s", buf); debug("ssh_exchange_identification: %s", buf);
} }
server_version_string = xstrdup(buf); server_version_string = xstrdup(buf);
free(fdset);
/* /*
* Check that the versions match. In future this might accept * Check that the versions match. In future this might accept