diff --git a/ssh-agent.c b/ssh-agent.c index eb8c2043d..d858c2470 100644 --- a/ssh-agent.c +++ b/ssh-agent.c @@ -1,4 +1,4 @@ -/* $OpenBSD: ssh-agent.c,v 1.222 2017/07/01 13:50:45 djm Exp $ */ +/* $OpenBSD: ssh-agent.c,v 1.223 2017/07/19 01:15:02 djm Exp $ */ /* * Author: Tatu Ylonen * Copyright (c) 1995 Tatu Ylonen , Espoo, Finland @@ -60,6 +60,9 @@ #ifdef HAVE_PATHS_H # include #endif +#ifdef HAVE_POLL_H +# include +#endif #include #include #include @@ -91,6 +94,9 @@ # define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*" #endif +/* Maximum accepted message length */ +#define AGENT_MAX_LEN (256*1024) + typedef enum { AUTH_UNUSED, AUTH_SOCKET, @@ -634,30 +640,46 @@ send: /* dispatch incoming messages */ -static void -process_message(SocketEntry *e) +static int +process_message(u_int socknum) { u_int msg_len; u_char type; const u_char *cp; int r; + SocketEntry *e; + + if (socknum >= sockets_alloc) { + fatal("%s: socket number %u >= allocated %u", + __func__, socknum, sockets_alloc); + } + e = &sockets[socknum]; if (sshbuf_len(e->input) < 5) - return; /* Incomplete message. */ + return 0; /* Incomplete message header. */ cp = sshbuf_ptr(e->input); msg_len = PEEK_U32(cp); - if (msg_len > 256 * 1024) { - close_socket(e); - return; + if (msg_len > AGENT_MAX_LEN) { + debug("%s: socket %u (fd=%d) message too long %u > %u", + __func__, socknum, e->fd, msg_len, AGENT_MAX_LEN); + return -1; } if (sshbuf_len(e->input) < msg_len + 4) - return; + return 0; /* Incomplete message body. */ /* move the current input to e->request */ sshbuf_reset(e->request); if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 || - (r = sshbuf_get_u8(e->request, &type)) != 0) + (r = sshbuf_get_u8(e->request, &type)) != 0) { + if (r == SSH_ERR_MESSAGE_INCOMPLETE || + r == SSH_ERR_STRING_TOO_LARGE) { + debug("%s: buffer error: %s", __func__, ssh_err(r)); + return -1; + } fatal("%s: buffer error: %s", __func__, ssh_err(r)); + } + + debug("%s: socket %u (fd=%d) type %d", __func__, socknum, e->fd, type); /* check wheter agent is locked */ if (locked && type != SSH_AGENTC_UNLOCK) { @@ -671,10 +693,9 @@ process_message(SocketEntry *e) /* send a fail message for all other request types */ send_status(e, 0); } - return; + return 0; } - debug("type %d", type); switch (type) { case SSH_AGENTC_LOCK: case SSH_AGENTC_UNLOCK: @@ -716,6 +737,7 @@ process_message(SocketEntry *e) send_status(e, 0); break; } + return 0; } static void @@ -757,19 +779,141 @@ new_socket(sock_type type, int fd) } static int -prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, - struct timeval **tvpp) +handle_socket_read(u_int socknum) { - u_int i, sz; - int n = 0; - static struct timeval tv; + struct sockaddr_un sunaddr; + socklen_t slen; + uid_t euid; + gid_t egid; + int fd; + + slen = sizeof(sunaddr); + fd = accept(sockets[socknum].fd, (struct sockaddr *)&sunaddr, &slen); + if (fd < 0) { + error("accept from AUTH_SOCKET: %s", strerror(errno)); + return -1; + } + if (getpeereid(fd, &euid, &egid) < 0) { + error("getpeereid %d failed: %s", fd, strerror(errno)); + close(fd); + return -1; + } + if ((euid != 0) && (getuid() != euid)) { + error("uid mismatch: peer euid %u != uid %u", + (u_int) euid, (u_int) getuid()); + close(fd); + return -1; + } + new_socket(AUTH_CONNECTION, fd); + return 0; +} + +static int +handle_conn_read(u_int socknum) +{ + char buf[1024]; + ssize_t len; + int r; + + if ((len = read(sockets[socknum].fd, buf, sizeof(buf))) <= 0) { + if (len == -1) { + if (errno == EAGAIN || errno == EINTR) + return 0; + error("%s: read error on socket %u (fd %d): %s", + __func__, socknum, sockets[socknum].fd, + strerror(errno)); + } + return -1; + } + if ((r = sshbuf_put(sockets[socknum].input, buf, len)) != 0) + fatal("%s: buffer error: %s", __func__, ssh_err(r)); + explicit_bzero(buf, sizeof(buf)); + process_message(socknum); + return 0; +} + +static int +handle_conn_write(u_int socknum) +{ + ssize_t len; + int r; + + if (sshbuf_len(sockets[socknum].output) == 0) + return 0; /* shouldn't happen */ + if ((len = write(sockets[socknum].fd, + sshbuf_ptr(sockets[socknum].output), + sshbuf_len(sockets[socknum].output))) <= 0) { + if (len == -1) { + if (errno == EAGAIN || errno == EINTR) + return 0; + error("%s: read error on socket %u (fd %d): %s", + __func__, socknum, sockets[socknum].fd, + strerror(errno)); + } + return -1; + } + if ((r = sshbuf_consume(sockets[socknum].output, len)) != 0) + fatal("%s: buffer error: %s", __func__, ssh_err(r)); + return 0; +} + +static void +after_poll(struct pollfd *pfd, size_t npfd) +{ + size_t i; + u_int socknum; + + for (i = 0; i < npfd; i++) { + if (pfd[i].revents == 0) + continue; + /* Find sockets entry */ + for (socknum = 0; socknum < sockets_alloc; socknum++) { + if (sockets[socknum].type != AUTH_SOCKET && + sockets[socknum].type != AUTH_CONNECTION) + continue; + if (pfd[i].fd == sockets[socknum].fd) + break; + } + if (socknum >= sockets_alloc) { + error("%s: no socket for fd %d", __func__, pfd[i].fd); + continue; + } + /* Process events */ + switch (sockets[socknum].type) { + case AUTH_SOCKET: + if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 && + handle_socket_read(socknum) != 0) + close_socket(&sockets[socknum]); + break; + case AUTH_CONNECTION: + if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 && + handle_conn_read(socknum) != 0) { + close_socket(&sockets[socknum]); + break; + } + if ((pfd[i].revents & (POLLOUT|POLLHUP)) != 0 && + handle_conn_write(socknum) != 0) + close_socket(&sockets[socknum]); + break; + default: + break; + } + } +} + +static int +prepare_poll(struct pollfd **pfdp, size_t *npfdp, int *timeoutp) +{ + struct pollfd *pfd = *pfdp; + size_t i, j, npfd = 0; time_t deadline; + /* Count active sockets */ for (i = 0; i < sockets_alloc; i++) { switch (sockets[i].type) { case AUTH_SOCKET: case AUTH_CONNECTION: - n = MAXIMUM(n, sockets[i].fd); + npfd++; break; case AUTH_UNUSED: break; @@ -778,28 +922,23 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, break; } } + if (npfd != *npfdp && + (pfd = recallocarray(pfd, *npfdp, npfd, sizeof(*pfd))) == NULL) + fatal("%s: recallocarray failed", __func__); + *pfdp = pfd; + *npfdp = npfd; - sz = howmany(n+1, NFDBITS) * sizeof(fd_mask); - if (*fdrp == NULL || sz > *nallocp) { - free(*fdrp); - free(*fdwp); - *fdrp = xmalloc(sz); - *fdwp = xmalloc(sz); - *nallocp = sz; - } - if (n < *fdl) - debug("XXX shrink: %d < %d", n, *fdl); - *fdl = n; - memset(*fdrp, 0, sz); - memset(*fdwp, 0, sz); - - for (i = 0; i < sockets_alloc; i++) { + for (i = j = 0; i < sockets_alloc; i++) { switch (sockets[i].type) { case AUTH_SOCKET: case AUTH_CONNECTION: - FD_SET(sockets[i].fd, *fdrp); + pfd[j].fd = sockets[i].fd; + pfd[j].revents = 0; + /* XXX backoff when input buffer full */ + pfd[j].events = POLLIN; if (sshbuf_len(sockets[i].output) > 0) - FD_SET(sockets[i].fd, *fdwp); + pfd[j].events |= POLLOUT; + j++; break; default: break; @@ -810,98 +949,16 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, deadline = (deadline == 0) ? parent_alive_interval : MINIMUM(deadline, parent_alive_interval); if (deadline == 0) { - *tvpp = NULL; + *timeoutp = INFTIM; } else { - tv.tv_sec = deadline; - tv.tv_usec = 0; - *tvpp = &tv; + if (deadline > INT_MAX / 1000) + *timeoutp = INT_MAX / 1000; + else + *timeoutp = deadline * 1000; } return (1); } -static void -after_select(fd_set *readset, fd_set *writeset) -{ - struct sockaddr_un sunaddr; - socklen_t slen; - char buf[1024]; - int len, sock, r; - u_int i, orig_alloc; - uid_t euid; - gid_t egid; - - for (i = 0, orig_alloc = sockets_alloc; i < orig_alloc; i++) - switch (sockets[i].type) { - case AUTH_UNUSED: - break; - case AUTH_SOCKET: - if (FD_ISSET(sockets[i].fd, readset)) { - slen = sizeof(sunaddr); - sock = accept(sockets[i].fd, - (struct sockaddr *)&sunaddr, &slen); - if (sock < 0) { - error("accept from AUTH_SOCKET: %s", - strerror(errno)); - break; - } - if (getpeereid(sock, &euid, &egid) < 0) { - error("getpeereid %d failed: %s", - sock, strerror(errno)); - close(sock); - break; - } - if ((euid != 0) && (getuid() != euid)) { - error("uid mismatch: " - "peer euid %u != uid %u", - (u_int) euid, (u_int) getuid()); - close(sock); - break; - } - new_socket(AUTH_CONNECTION, sock); - } - break; - case AUTH_CONNECTION: - if (sshbuf_len(sockets[i].output) > 0 && - FD_ISSET(sockets[i].fd, writeset)) { - len = write(sockets[i].fd, - sshbuf_ptr(sockets[i].output), - sshbuf_len(sockets[i].output)); - if (len == -1 && (errno == EAGAIN || - errno == EWOULDBLOCK || - errno == EINTR)) - continue; - if (len <= 0) { - close_socket(&sockets[i]); - break; - } - if ((r = sshbuf_consume(sockets[i].output, - len)) != 0) - fatal("%s: buffer error: %s", - __func__, ssh_err(r)); - } - if (FD_ISSET(sockets[i].fd, readset)) { - len = read(sockets[i].fd, buf, sizeof(buf)); - if (len == -1 && (errno == EAGAIN || - errno == EWOULDBLOCK || - errno == EINTR)) - continue; - if (len <= 0) { - close_socket(&sockets[i]); - break; - } - if ((r = sshbuf_put(sockets[i].input, - buf, len)) != 0) - fatal("%s: buffer error: %s", - __func__, ssh_err(r)); - explicit_bzero(buf, sizeof(buf)); - process_message(&sockets[i]); - } - break; - default: - fatal("Unknown type %d", sockets[i].type); - } -} - static void cleanup_socket(void) { @@ -963,7 +1020,6 @@ main(int ac, char **av) int sock, fd, ch, result, saved_errno; u_int nalloc; char *shell, *format, *pidstr, *agentsocket = NULL; - fd_set *readsetp = NULL, *writesetp = NULL; #ifdef HAVE_SETRLIMIT struct rlimit rlim; #endif @@ -971,9 +1027,11 @@ main(int ac, char **av) extern char *optarg; pid_t pid; char pidstrbuf[1 + 3 * sizeof pid]; - struct timeval *tvp = NULL; size_t len; mode_t prev_mask; + int timeout = INFTIM; + struct pollfd *pfd = NULL; + size_t npfd = 0; ssh_malloc_init(); /* must be called before any mallocs */ /* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */ @@ -1201,8 +1259,8 @@ skip: platform_pledge_agent(); while (1) { - prepare_select(&readsetp, &writesetp, &max_fd, &nalloc, &tvp); - result = select(max_fd + 1, readsetp, writesetp, NULL, tvp); + prepare_poll(&pfd, &npfd, &timeout); + result = poll(pfd, npfd, timeout); saved_errno = errno; if (parent_alive_interval != 0) check_parent_exists(); @@ -1210,9 +1268,9 @@ skip: if (result < 0) { if (saved_errno == EINTR) continue; - fatal("select: %s", strerror(saved_errno)); + fatal("poll: %s", strerror(saved_errno)); } else if (result > 0) - after_select(readsetp, writesetp); + after_poll(pfd, npfd); } /* NOTREACHED */ }