Bugfixes for SSL sockets.

This commit is contained in:
Gunnar Beutner 2012-07-16 11:44:11 +02:00
parent dd26fd46f5
commit 6ebb1bf192
4 changed files with 52 additions and 21 deletions

View File

@ -316,12 +316,8 @@ void Socket::ReadThreadProc(void)
return;
}
if (FD_ISSET(fd, &readfds)) {
if (!m_Connected)
m_Connected = true;
if (FD_ISSET(fd, &readfds))
HandleReadable();
}
if (FD_ISSET(fd, &exceptfds))
HandleException();
@ -340,7 +336,7 @@ void Socket::WriteThreadProc(void)
FD_ZERO(&writefds);
while (!WantsToWrite() && m_Connected) {
while (!WantsToWrite()) {
m_WriteCV.timed_wait(lock, boost::posix_time::seconds(1));
if (GetFD() == INVALID_SOCKET)
@ -368,12 +364,8 @@ void Socket::WriteThreadProc(void)
return;
}
if (FD_ISSET(fd, &writefds)) {
if (!m_Connected)
m_Connected = true;
if (FD_ISSET(fd, &writefds))
HandleWritable();
}
}
}
@ -381,3 +373,13 @@ mutex& Socket::GetMutex(void) const
{
return m_Mutex;
}
void Socket::SetConnected(bool connected)
{
m_Connected = connected;
}
bool Socket::IsConnected(void) const
{
return m_Connected;
}

View File

@ -53,6 +53,9 @@ protected:
void SetFD(SOCKET fd);
SOCKET GetFD(void) const;
void SetConnected(bool connected);
bool IsConnected(void) const;
int GetError(void) const;
static int GetLastSocketError(void);
void HandleSocketError(const exception& ex);

View File

@ -120,10 +120,14 @@ void TcpClient::HandleWritable(void)
rc = send(GetFD(), (const char *)data, count, 0);
if (rc <= 0) {
SetConnected(false);
HandleSocketError(SocketException("send() failed", GetError()));
return;
}
SetConnected(true);
m_SendQueue->Read(NULL, rc);
}
}
@ -182,10 +186,14 @@ void TcpClient::HandleReadable(void)
return;
if (rc <= 0) {
SetConnected(false);
HandleSocketError(SocketException("recv() failed", GetError()));
return;
}
SetConnected(true);
m_RecvQueue->Write(data, rc);
}

View File

@ -118,7 +118,13 @@ void TlsClient::HandleReadable(void)
for (;;) {
char data[1024];
int rc = SSL_read(m_SSL.get(), data, sizeof(data));
int rc;
if (IsConnected()) {
rc = SSL_read(m_SSL.get(), data, sizeof(data));
} else {
rc = SSL_do_handshake(m_SSL.get());
}
if (rc <= 0) {
switch (SSL_get_error(m_SSL.get(), rc)) {
@ -137,7 +143,10 @@ void TlsClient::HandleReadable(void)
}
}
m_RecvQueue->Write(data, rc);
if (IsConnected())
m_RecvQueue->Write(data, rc);
else
SetConnected(true);
}
post_event:
@ -156,17 +165,23 @@ void TlsClient::HandleWritable(void)
size_t count;
for (;;) {
count = m_SendQueue->GetAvailableBytes();
int rc;
if (count == 0)
break;
if (IsConnected()) {
count = m_SendQueue->GetAvailableBytes();
if (count > sizeof(data))
count = sizeof(data);
if (count == 0)
break;
m_SendQueue->Peek(data, count);
if (count > sizeof(data))
count = sizeof(data);
int rc = SSL_write(m_SSL.get(), (const char *)data, count);
m_SendQueue->Peek(data, count);
rc = SSL_write(m_SSL.get(), (const char *)data, count);
} else {
rc = SSL_do_handshake(m_SSL.get());
}
if (rc <= 0) {
switch (SSL_get_error(m_SSL.get(), rc)) {
@ -185,7 +200,10 @@ void TlsClient::HandleWritable(void)
}
}
m_SendQueue->Read(NULL, rc);
if (IsConnected())
m_SendQueue->Read(NULL, rc);
else
SetConnected(true);
}
}