diff --git a/base/socket.cpp b/base/socket.cpp index 810ce314b..7ca0f949b 100644 --- a/base/socket.cpp +++ b/base/socket.cpp @@ -316,12 +316,8 @@ void Socket::ReadThreadProc(void) return; } - if (FD_ISSET(fd, &readfds)) { - if (!m_Connected) - m_Connected = true; - + if (FD_ISSET(fd, &readfds)) HandleReadable(); - } if (FD_ISSET(fd, &exceptfds)) HandleException(); @@ -340,7 +336,7 @@ void Socket::WriteThreadProc(void) FD_ZERO(&writefds); - while (!WantsToWrite() && m_Connected) { + while (!WantsToWrite()) { m_WriteCV.timed_wait(lock, boost::posix_time::seconds(1)); if (GetFD() == INVALID_SOCKET) @@ -368,12 +364,8 @@ void Socket::WriteThreadProc(void) return; } - if (FD_ISSET(fd, &writefds)) { - if (!m_Connected) - m_Connected = true; - + if (FD_ISSET(fd, &writefds)) HandleWritable(); - } } } @@ -381,3 +373,13 @@ mutex& Socket::GetMutex(void) const { return m_Mutex; } + +void Socket::SetConnected(bool connected) +{ + m_Connected = connected; +} + +bool Socket::IsConnected(void) const +{ + return m_Connected; +} diff --git a/base/socket.h b/base/socket.h index e046f5a29..5eeb0fe5b 100644 --- a/base/socket.h +++ b/base/socket.h @@ -53,6 +53,9 @@ protected: void SetFD(SOCKET fd); SOCKET GetFD(void) const; + void SetConnected(bool connected); + bool IsConnected(void) const; + int GetError(void) const; static int GetLastSocketError(void); void HandleSocketError(const exception& ex); diff --git a/base/tcpclient.cpp b/base/tcpclient.cpp index ce9c89a42..c797f3bcd 100644 --- a/base/tcpclient.cpp +++ b/base/tcpclient.cpp @@ -120,10 +120,14 @@ void TcpClient::HandleWritable(void) rc = send(GetFD(), (const char *)data, count, 0); if (rc <= 0) { + SetConnected(false); + HandleSocketError(SocketException("send() failed", GetError())); return; } + SetConnected(true); + m_SendQueue->Read(NULL, rc); } } @@ -182,10 +186,14 @@ void TcpClient::HandleReadable(void) return; if (rc <= 0) { + SetConnected(false); + HandleSocketError(SocketException("recv() failed", GetError())); return; } + SetConnected(true); + m_RecvQueue->Write(data, rc); } diff --git a/base/tlsclient.cpp b/base/tlsclient.cpp index 593bfed85..c948e8903 100644 --- a/base/tlsclient.cpp +++ b/base/tlsclient.cpp @@ -118,7 +118,13 @@ void TlsClient::HandleReadable(void) for (;;) { char data[1024]; - int rc = SSL_read(m_SSL.get(), data, sizeof(data)); + int rc; + + if (IsConnected()) { + rc = SSL_read(m_SSL.get(), data, sizeof(data)); + } else { + rc = SSL_do_handshake(m_SSL.get()); + } if (rc <= 0) { switch (SSL_get_error(m_SSL.get(), rc)) { @@ -137,7 +143,10 @@ void TlsClient::HandleReadable(void) } } - m_RecvQueue->Write(data, rc); + if (IsConnected()) + m_RecvQueue->Write(data, rc); + else + SetConnected(true); } post_event: @@ -156,17 +165,23 @@ void TlsClient::HandleWritable(void) size_t count; for (;;) { - count = m_SendQueue->GetAvailableBytes(); + int rc; - if (count == 0) - break; + if (IsConnected()) { + count = m_SendQueue->GetAvailableBytes(); - if (count > sizeof(data)) - count = sizeof(data); + if (count == 0) + break; - m_SendQueue->Peek(data, count); + if (count > sizeof(data)) + count = sizeof(data); - int rc = SSL_write(m_SSL.get(), (const char *)data, count); + m_SendQueue->Peek(data, count); + + rc = SSL_write(m_SSL.get(), (const char *)data, count); + } else { + rc = SSL_do_handshake(m_SSL.get()); + } if (rc <= 0) { switch (SSL_get_error(m_SSL.get(), rc)) { @@ -185,7 +200,10 @@ void TlsClient::HandleWritable(void) } } - m_SendQueue->Read(NULL, rc); + if (IsConnected()) + m_SendQueue->Read(NULL, rc); + else + SetConnected(true); } }