upstream commit

switch from select() to poll() for the ssh-agent
mainloop; ok markus

Upstream-ID: 4a94888ee67b3fd948fd10693973beb12f802448
This commit is contained in:
djm@openbsd.org 2017-07-19 01:15:02 +00:00 committed by Damien Miller
parent b1e72df2b8
commit fd0e8fa5f8
1 changed files with 185 additions and 127 deletions

View File

@ -1,4 +1,4 @@
/* $OpenBSD: ssh-agent.c,v 1.222 2017/07/01 13:50:45 djm Exp $ */
/* $OpenBSD: ssh-agent.c,v 1.223 2017/07/19 01:15:02 djm Exp $ */
/*
* Author: Tatu Ylonen <ylo@cs.hut.fi>
* Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@ -60,6 +60,9 @@
#ifdef HAVE_PATHS_H
# include <paths.h>
#endif
#ifdef HAVE_POLL_H
# include <poll.h>
#endif
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
@ -91,6 +94,9 @@
# define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*"
#endif
/* Maximum accepted message length */
#define AGENT_MAX_LEN (256*1024)
typedef enum {
AUTH_UNUSED,
AUTH_SOCKET,
@ -634,30 +640,46 @@ send:
/* dispatch incoming messages */
static void
process_message(SocketEntry *e)
static int
process_message(u_int socknum)
{
u_int msg_len;
u_char type;
const u_char *cp;
int r;
SocketEntry *e;
if (socknum >= sockets_alloc) {
fatal("%s: socket number %u >= allocated %u",
__func__, socknum, sockets_alloc);
}
e = &sockets[socknum];
if (sshbuf_len(e->input) < 5)
return; /* Incomplete message. */
return 0; /* Incomplete message header. */
cp = sshbuf_ptr(e->input);
msg_len = PEEK_U32(cp);
if (msg_len > 256 * 1024) {
close_socket(e);
return;
if (msg_len > AGENT_MAX_LEN) {
debug("%s: socket %u (fd=%d) message too long %u > %u",
__func__, socknum, e->fd, msg_len, AGENT_MAX_LEN);
return -1;
}
if (sshbuf_len(e->input) < msg_len + 4)
return;
return 0; /* Incomplete message body. */
/* move the current input to e->request */
sshbuf_reset(e->request);
if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 ||
(r = sshbuf_get_u8(e->request, &type)) != 0)
(r = sshbuf_get_u8(e->request, &type)) != 0) {
if (r == SSH_ERR_MESSAGE_INCOMPLETE ||
r == SSH_ERR_STRING_TOO_LARGE) {
debug("%s: buffer error: %s", __func__, ssh_err(r));
return -1;
}
fatal("%s: buffer error: %s", __func__, ssh_err(r));
}
debug("%s: socket %u (fd=%d) type %d", __func__, socknum, e->fd, type);
/* check wheter agent is locked */
if (locked && type != SSH_AGENTC_UNLOCK) {
@ -671,10 +693,9 @@ process_message(SocketEntry *e)
/* send a fail message for all other request types */
send_status(e, 0);
}
return;
return 0;
}
debug("type %d", type);
switch (type) {
case SSH_AGENTC_LOCK:
case SSH_AGENTC_UNLOCK:
@ -716,6 +737,7 @@ process_message(SocketEntry *e)
send_status(e, 0);
break;
}
return 0;
}
static void
@ -757,19 +779,141 @@ new_socket(sock_type type, int fd)
}
static int
prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
struct timeval **tvpp)
handle_socket_read(u_int socknum)
{
u_int i, sz;
int n = 0;
static struct timeval tv;
struct sockaddr_un sunaddr;
socklen_t slen;
uid_t euid;
gid_t egid;
int fd;
slen = sizeof(sunaddr);
fd = accept(sockets[socknum].fd, (struct sockaddr *)&sunaddr, &slen);
if (fd < 0) {
error("accept from AUTH_SOCKET: %s", strerror(errno));
return -1;
}
if (getpeereid(fd, &euid, &egid) < 0) {
error("getpeereid %d failed: %s", fd, strerror(errno));
close(fd);
return -1;
}
if ((euid != 0) && (getuid() != euid)) {
error("uid mismatch: peer euid %u != uid %u",
(u_int) euid, (u_int) getuid());
close(fd);
return -1;
}
new_socket(AUTH_CONNECTION, fd);
return 0;
}
static int
handle_conn_read(u_int socknum)
{
char buf[1024];
ssize_t len;
int r;
if ((len = read(sockets[socknum].fd, buf, sizeof(buf))) <= 0) {
if (len == -1) {
if (errno == EAGAIN || errno == EINTR)
return 0;
error("%s: read error on socket %u (fd %d): %s",
__func__, socknum, sockets[socknum].fd,
strerror(errno));
}
return -1;
}
if ((r = sshbuf_put(sockets[socknum].input, buf, len)) != 0)
fatal("%s: buffer error: %s", __func__, ssh_err(r));
explicit_bzero(buf, sizeof(buf));
process_message(socknum);
return 0;
}
static int
handle_conn_write(u_int socknum)
{
ssize_t len;
int r;
if (sshbuf_len(sockets[socknum].output) == 0)
return 0; /* shouldn't happen */
if ((len = write(sockets[socknum].fd,
sshbuf_ptr(sockets[socknum].output),
sshbuf_len(sockets[socknum].output))) <= 0) {
if (len == -1) {
if (errno == EAGAIN || errno == EINTR)
return 0;
error("%s: read error on socket %u (fd %d): %s",
__func__, socknum, sockets[socknum].fd,
strerror(errno));
}
return -1;
}
if ((r = sshbuf_consume(sockets[socknum].output, len)) != 0)
fatal("%s: buffer error: %s", __func__, ssh_err(r));
return 0;
}
static void
after_poll(struct pollfd *pfd, size_t npfd)
{
size_t i;
u_int socknum;
for (i = 0; i < npfd; i++) {
if (pfd[i].revents == 0)
continue;
/* Find sockets entry */
for (socknum = 0; socknum < sockets_alloc; socknum++) {
if (sockets[socknum].type != AUTH_SOCKET &&
sockets[socknum].type != AUTH_CONNECTION)
continue;
if (pfd[i].fd == sockets[socknum].fd)
break;
}
if (socknum >= sockets_alloc) {
error("%s: no socket for fd %d", __func__, pfd[i].fd);
continue;
}
/* Process events */
switch (sockets[socknum].type) {
case AUTH_SOCKET:
if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
handle_socket_read(socknum) != 0)
close_socket(&sockets[socknum]);
break;
case AUTH_CONNECTION:
if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
handle_conn_read(socknum) != 0) {
close_socket(&sockets[socknum]);
break;
}
if ((pfd[i].revents & (POLLOUT|POLLHUP)) != 0 &&
handle_conn_write(socknum) != 0)
close_socket(&sockets[socknum]);
break;
default:
break;
}
}
}
static int
prepare_poll(struct pollfd **pfdp, size_t *npfdp, int *timeoutp)
{
struct pollfd *pfd = *pfdp;
size_t i, j, npfd = 0;
time_t deadline;
/* Count active sockets */
for (i = 0; i < sockets_alloc; i++) {
switch (sockets[i].type) {
case AUTH_SOCKET:
case AUTH_CONNECTION:
n = MAXIMUM(n, sockets[i].fd);
npfd++;
break;
case AUTH_UNUSED:
break;
@ -778,28 +922,23 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
break;
}
}
if (npfd != *npfdp &&
(pfd = recallocarray(pfd, *npfdp, npfd, sizeof(*pfd))) == NULL)
fatal("%s: recallocarray failed", __func__);
*pfdp = pfd;
*npfdp = npfd;
sz = howmany(n+1, NFDBITS) * sizeof(fd_mask);
if (*fdrp == NULL || sz > *nallocp) {
free(*fdrp);
free(*fdwp);
*fdrp = xmalloc(sz);
*fdwp = xmalloc(sz);
*nallocp = sz;
}
if (n < *fdl)
debug("XXX shrink: %d < %d", n, *fdl);
*fdl = n;
memset(*fdrp, 0, sz);
memset(*fdwp, 0, sz);
for (i = 0; i < sockets_alloc; i++) {
for (i = j = 0; i < sockets_alloc; i++) {
switch (sockets[i].type) {
case AUTH_SOCKET:
case AUTH_CONNECTION:
FD_SET(sockets[i].fd, *fdrp);
pfd[j].fd = sockets[i].fd;
pfd[j].revents = 0;
/* XXX backoff when input buffer full */
pfd[j].events = POLLIN;
if (sshbuf_len(sockets[i].output) > 0)
FD_SET(sockets[i].fd, *fdwp);
pfd[j].events |= POLLOUT;
j++;
break;
default:
break;
@ -810,98 +949,16 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
deadline = (deadline == 0) ? parent_alive_interval :
MINIMUM(deadline, parent_alive_interval);
if (deadline == 0) {
*tvpp = NULL;
*timeoutp = INFTIM;
} else {
tv.tv_sec = deadline;
tv.tv_usec = 0;
*tvpp = &tv;
if (deadline > INT_MAX / 1000)
*timeoutp = INT_MAX / 1000;
else
*timeoutp = deadline * 1000;
}
return (1);
}
static void
after_select(fd_set *readset, fd_set *writeset)
{
struct sockaddr_un sunaddr;
socklen_t slen;
char buf[1024];
int len, sock, r;
u_int i, orig_alloc;
uid_t euid;
gid_t egid;
for (i = 0, orig_alloc = sockets_alloc; i < orig_alloc; i++)
switch (sockets[i].type) {
case AUTH_UNUSED:
break;
case AUTH_SOCKET:
if (FD_ISSET(sockets[i].fd, readset)) {
slen = sizeof(sunaddr);
sock = accept(sockets[i].fd,
(struct sockaddr *)&sunaddr, &slen);
if (sock < 0) {
error("accept from AUTH_SOCKET: %s",
strerror(errno));
break;
}
if (getpeereid(sock, &euid, &egid) < 0) {
error("getpeereid %d failed: %s",
sock, strerror(errno));
close(sock);
break;
}
if ((euid != 0) && (getuid() != euid)) {
error("uid mismatch: "
"peer euid %u != uid %u",
(u_int) euid, (u_int) getuid());
close(sock);
break;
}
new_socket(AUTH_CONNECTION, sock);
}
break;
case AUTH_CONNECTION:
if (sshbuf_len(sockets[i].output) > 0 &&
FD_ISSET(sockets[i].fd, writeset)) {
len = write(sockets[i].fd,
sshbuf_ptr(sockets[i].output),
sshbuf_len(sockets[i].output));
if (len == -1 && (errno == EAGAIN ||
errno == EWOULDBLOCK ||
errno == EINTR))
continue;
if (len <= 0) {
close_socket(&sockets[i]);
break;
}
if ((r = sshbuf_consume(sockets[i].output,
len)) != 0)
fatal("%s: buffer error: %s",
__func__, ssh_err(r));
}
if (FD_ISSET(sockets[i].fd, readset)) {
len = read(sockets[i].fd, buf, sizeof(buf));
if (len == -1 && (errno == EAGAIN ||
errno == EWOULDBLOCK ||
errno == EINTR))
continue;
if (len <= 0) {
close_socket(&sockets[i]);
break;
}
if ((r = sshbuf_put(sockets[i].input,
buf, len)) != 0)
fatal("%s: buffer error: %s",
__func__, ssh_err(r));
explicit_bzero(buf, sizeof(buf));
process_message(&sockets[i]);
}
break;
default:
fatal("Unknown type %d", sockets[i].type);
}
}
static void
cleanup_socket(void)
{
@ -963,7 +1020,6 @@ main(int ac, char **av)
int sock, fd, ch, result, saved_errno;
u_int nalloc;
char *shell, *format, *pidstr, *agentsocket = NULL;
fd_set *readsetp = NULL, *writesetp = NULL;
#ifdef HAVE_SETRLIMIT
struct rlimit rlim;
#endif
@ -971,9 +1027,11 @@ main(int ac, char **av)
extern char *optarg;
pid_t pid;
char pidstrbuf[1 + 3 * sizeof pid];
struct timeval *tvp = NULL;
size_t len;
mode_t prev_mask;
int timeout = INFTIM;
struct pollfd *pfd = NULL;
size_t npfd = 0;
ssh_malloc_init(); /* must be called before any mallocs */
/* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */
@ -1201,8 +1259,8 @@ skip:
platform_pledge_agent();
while (1) {
prepare_select(&readsetp, &writesetp, &max_fd, &nalloc, &tvp);
result = select(max_fd + 1, readsetp, writesetp, NULL, tvp);
prepare_poll(&pfd, &npfd, &timeout);
result = poll(pfd, npfd, timeout);
saved_errno = errno;
if (parent_alive_interval != 0)
check_parent_exists();
@ -1210,9 +1268,9 @@ skip:
if (result < 0) {
if (saved_errno == EINTR)
continue;
fatal("select: %s", strerror(saved_errno));
fatal("poll: %s", strerror(saved_errno));
} else if (result > 0)
after_select(readsetp, writesetp);
after_poll(pfd, npfd);
}
/* NOTREACHED */
}