upstream: Apply ConnectTimeout to multiplexing local socket

connections.  If the multiplex socket exists but the connection times out,
ssh will fall back to a direct connection the same way it would if the socket
did not exist at all.  ok djm@

OpenBSD-Commit-ID: 2fbe1a36d4a24b98531b2d298a6557c8285dc1b4
This commit is contained in:
dtucker@openbsd.org 2023-08-04 06:32:40 +00:00 committed by Darren Tucker
parent 9d92e7b248
commit e535fbe2af
No known key found for this signature in database
2 changed files with 27 additions and 17 deletions

7
misc.c
View File

@ -1,4 +1,4 @@
/* $OpenBSD: misc.c,v 1.184 2023/07/19 14:02:27 djm Exp $ */ /* $OpenBSD: misc.c,v 1.185 2023/08/04 06:32:40 dtucker Exp $ */
/* /*
* Copyright (c) 2000 Markus Friedl. All rights reserved. * Copyright (c) 2000 Markus Friedl. All rights reserved.
* Copyright (c) 2005-2020 Damien Miller. All rights reserved. * Copyright (c) 2005-2020 Damien Miller. All rights reserved.
@ -317,14 +317,15 @@ waitfd(int fd, int *timeoutp, short events)
{ {
struct pollfd pfd; struct pollfd pfd;
struct timeval t_start; struct timeval t_start;
int oerrno, r; int oerrno, r, have_timeout = (*timeoutp >= 0);
pfd.fd = fd; pfd.fd = fd;
pfd.events = events; pfd.events = events;
for (; *timeoutp >= 0;) { for (; !have_timeout || *timeoutp >= 0;) {
monotime_tv(&t_start); monotime_tv(&t_start);
r = poll(&pfd, 1, *timeoutp); r = poll(&pfd, 1, *timeoutp);
oerrno = errno; oerrno = errno;
if (have_timeout)
ms_subtract_diff(&t_start, timeoutp); ms_subtract_diff(&t_start, timeoutp);
errno = oerrno; errno = oerrno;
if (r > 0) if (r > 0)

35
mux.c
View File

@ -1,4 +1,4 @@
/* $OpenBSD: mux.c,v 1.98 2023/07/26 23:06:00 djm Exp $ */ /* $OpenBSD: mux.c,v 1.99 2023/08/04 06:32:40 dtucker Exp $ */
/* /*
* Copyright (c) 2002-2008 Damien Miller <djm@openbsd.org> * Copyright (c) 2002-2008 Damien Miller <djm@openbsd.org>
* *
@ -68,6 +68,7 @@
#include "readconf.h" #include "readconf.h"
#include "clientloop.h" #include "clientloop.h"
#include "ssherr.h" #include "ssherr.h"
#include "misc.h"
/* from ssh.c */ /* from ssh.c */
extern int tty_flag; extern int tty_flag;
@ -1458,16 +1459,13 @@ control_client_sigrelay(int signo)
} }
static int static int
mux_client_read(int fd, struct sshbuf *b, size_t need) mux_client_read(int fd, struct sshbuf *b, size_t need, int timeout_ms)
{ {
size_t have; size_t have;
ssize_t len; ssize_t len;
u_char *p; u_char *p;
struct pollfd pfd;
int r; int r;
pfd.fd = fd;
pfd.events = POLLIN;
if ((r = sshbuf_reserve(b, need, &p)) != 0) if ((r = sshbuf_reserve(b, need, &p)) != 0)
fatal_fr(r, "reserve"); fatal_fr(r, "reserve");
for (have = 0; have < need; ) { for (have = 0; have < need; ) {
@ -1482,7 +1480,8 @@ mux_client_read(int fd, struct sshbuf *b, size_t need)
case EWOULDBLOCK: case EWOULDBLOCK:
#endif #endif
case EAGAIN: case EAGAIN:
(void)poll(&pfd, 1, -1); if (waitrfd(fd, &timeout_ms) == -1)
return -1; /* timeout */
/* FALLTHROUGH */ /* FALLTHROUGH */
case EINTR: case EINTR:
continue; continue;
@ -1554,7 +1553,7 @@ mux_client_write_packet(int fd, struct sshbuf *m)
} }
static int static int
mux_client_read_packet(int fd, struct sshbuf *m) mux_client_read_packet_timeout(int fd, struct sshbuf *m, int timeout_ms)
{ {
struct sshbuf *queue; struct sshbuf *queue;
size_t need, have; size_t need, have;
@ -1563,7 +1562,7 @@ mux_client_read_packet(int fd, struct sshbuf *m)
if ((queue = sshbuf_new()) == NULL) if ((queue = sshbuf_new()) == NULL)
fatal_f("sshbuf_new"); fatal_f("sshbuf_new");
if (mux_client_read(fd, queue, 4) != 0) { if (mux_client_read(fd, queue, 4, timeout_ms) != 0) {
if ((oerrno = errno) == EPIPE) if ((oerrno = errno) == EPIPE)
debug3_f("read header failed: %s", debug3_f("read header failed: %s",
strerror(errno)); strerror(errno));
@ -1572,7 +1571,7 @@ mux_client_read_packet(int fd, struct sshbuf *m)
return -1; return -1;
} }
need = PEEK_U32(sshbuf_ptr(queue)); need = PEEK_U32(sshbuf_ptr(queue));
if (mux_client_read(fd, queue, need) != 0) { if (mux_client_read(fd, queue, need, timeout_ms) != 0) {
oerrno = errno; oerrno = errno;
debug3_f("read body failed: %s", strerror(errno)); debug3_f("read body failed: %s", strerror(errno));
sshbuf_free(queue); sshbuf_free(queue);
@ -1587,7 +1586,13 @@ mux_client_read_packet(int fd, struct sshbuf *m)
} }
static int static int
mux_client_hello_exchange(int fd) mux_client_read_packet(int fd, struct sshbuf *m)
{
return mux_client_read_packet_timeout(fd, m, -1);
}
static int
mux_client_hello_exchange(int fd, int timeout_ms)
{ {
struct sshbuf *m; struct sshbuf *m;
u_int type, ver; u_int type, ver;
@ -1608,7 +1613,7 @@ mux_client_hello_exchange(int fd)
sshbuf_reset(m); sshbuf_reset(m);
/* Read their HELLO */ /* Read their HELLO */
if (mux_client_read_packet(fd, m) != 0) { if (mux_client_read_packet_timeout(fd, m, timeout_ms) != 0) {
debug_f("read packet failed"); debug_f("read packet failed");
goto out; goto out;
} }
@ -2258,7 +2263,7 @@ int
muxclient(const char *path) muxclient(const char *path)
{ {
struct sockaddr_un addr; struct sockaddr_un addr;
int sock; int sock, timeout = options.connection_timeout, timeout_ms = -1;
u_int pid; u_int pid;
if (muxclient_command == 0) { if (muxclient_command == 0) {
@ -2314,7 +2319,11 @@ muxclient(const char *path)
} }
set_nonblock(sock); set_nonblock(sock);
if (mux_client_hello_exchange(sock) != 0) { /* Timeout on initial connection only. */
if (timeout > 0 && timeout < INT_MAX / 1000)
timeout_ms = timeout * 1000;
if (mux_client_hello_exchange(sock, timeout_ms) != 0) {
error_f("master hello exchange failed"); error_f("master hello exchange failed");
close(sock); close(sock);
return -1; return -1;