From 3f647bb7797b3e71405c59eb280a4be74305c6b2 Mon Sep 17 00:00:00 2001 From: Gunnar Beutner Date: Fri, 1 Aug 2014 14:28:32 +0200 Subject: [PATCH] Fix OpenSSL errors during (re-)negotiation fixes #6724 --- lib/base/tlsstream.cpp | 51 +++++++++++++++-------------------------- lib/base/tlsstream.hpp | 3 +-- lib/base/tlsutility.cpp | 12 +++++++++- 3 files changed, 30 insertions(+), 36 deletions(-) diff --git a/lib/base/tlsstream.cpp b/lib/base/tlsstream.cpp index 4d8883a62..a01f15f13 100644 --- a/lib/base/tlsstream.cpp +++ b/lib/base/tlsstream.cpp @@ -35,7 +35,7 @@ bool I2_EXPORT TlsStream::m_SSLIndexInitialized = false; * @param role The role of the client. * @param sslContext The SSL context for the client. */ -TlsStream::TlsStream(const Socket::Ptr& socket, ConnectionRole role, shared_ptr sslContext) +TlsStream::TlsStream(const Socket::Ptr& socket, ConnectionRole role, const shared_ptr& sslContext) : m_Socket(socket), m_Role(role) { m_SSL = shared_ptr(SSL_new(sslContext.get()), SSL_free); @@ -61,9 +61,7 @@ TlsStream::TlsStream(const Socket::Ptr& socket, ConnectionRole role, shared_ptr< socket->MakeNonBlocking(); - m_BIO = BIO_new_socket(socket->GetFD(), 0); - BIO_set_nbio(m_BIO, 1); - SSL_set_bio(m_SSL.get(), m_BIO, m_BIO); + SSL_set_fd(m_SSL.get(), socket->GetFD()); if (m_Role == RoleServer) SSL_set_accept_state(m_SSL.get()); @@ -96,15 +94,12 @@ void TlsStream::Handshake(void) for (;;) { int rc, err; - { - boost::mutex::scoped_lock lock(m_SSLLock); - rc = SSL_do_handshake(m_SSL.get()); + rc = SSL_do_handshake(m_SSL.get()); - if (rc > 0) - break; + if (rc > 0) + break; - err = SSL_get_error(m_SSL.get(), rc); - } + err = SSL_get_error(m_SSL.get(), rc); switch (err) { case SSL_ERROR_WANT_READ: @@ -142,13 +137,10 @@ size_t TlsStream::Read(void *buffer, size_t count) while (left > 0) { int rc, err; - { - boost::mutex::scoped_lock lock(m_SSLLock); - rc = SSL_read(m_SSL.get(), ((char *)buffer) + (count - left), left); + rc = SSL_read(m_SSL.get(), ((char *)buffer) + (count - left), left); - if (rc <= 0) - err = SSL_get_error(m_SSL.get(), rc); - } + if (rc <= 0) + err = SSL_get_error(m_SSL.get(), rc); if (rc <= 0) { switch (err) { @@ -189,13 +181,10 @@ void TlsStream::Write(const void *buffer, size_t count) while (left > 0) { int rc, err; - { - boost::mutex::scoped_lock lock(m_SSLLock); - rc = SSL_write(m_SSL.get(), ((const char *)buffer) + (count - left), left); + rc = SSL_write(m_SSL.get(), ((const char *)buffer) + (count - left), left); - if (rc <= 0) - err = SSL_get_error(m_SSL.get(), rc); - } + if (rc <= 0) + err = SSL_get_error(m_SSL.get(), rc); if (rc <= 0) { switch (err) { @@ -235,18 +224,14 @@ void TlsStream::Close(void) for (;;) { int rc, err; - { - boost::mutex::scoped_lock lock(m_SSLLock); + do { + rc = SSL_shutdown(m_SSL.get()); + } while (rc == 0); - do { - rc = SSL_shutdown(m_SSL.get()); - } while (rc == 0); + if (rc > 0) + break; - if (rc > 0) - break; - - err = SSL_get_error(m_SSL.get(), rc); - } + err = SSL_get_error(m_SSL.get(), rc); switch (err) { case SSL_ERROR_WANT_READ: diff --git a/lib/base/tlsstream.hpp b/lib/base/tlsstream.hpp index 0d26656d0..28362b1f8 100644 --- a/lib/base/tlsstream.hpp +++ b/lib/base/tlsstream.hpp @@ -38,7 +38,7 @@ class I2_BASE_API TlsStream : public Stream public: DECLARE_PTR_TYPEDEFS(TlsStream); - TlsStream(const Socket::Ptr& socket, ConnectionRole role, shared_ptr sslContext); + TlsStream(const Socket::Ptr& socket, ConnectionRole role, const shared_ptr& sslContext); shared_ptr GetClientCertificate(void) const; shared_ptr GetPeerCertificate(void) const; @@ -53,7 +53,6 @@ public: virtual bool IsEof(void) const; private: - boost::mutex m_SSLLock; shared_ptr m_SSL; BIO *m_BIO; diff --git a/lib/base/tlsutility.cpp b/lib/base/tlsutility.cpp index 0d0dd7613..2643d55ca 100644 --- a/lib/base/tlsutility.cpp +++ b/lib/base/tlsutility.cpp @@ -33,6 +33,15 @@ static void OpenSSLLockingCallback(int mode, int type, const char *, int) l_Mutexes[type].unlock(); } +static unsigned long OpenSSLIDCallback(void) +{ +#ifdef _WIN32 + return static_cast(GetCurrentThreadId()); +#else /* _WIN32 */ + return static_cast(pthread_self()); +#endif /* _WIN32 */ +} + /** * Initializes the OpenSSL library. */ @@ -48,6 +57,7 @@ static void InitializeOpenSSL(void) l_Mutexes = new boost::mutex[CRYPTO_num_locks()]; CRYPTO_set_locking_callback(&OpenSSLLockingCallback); + CRYPTO_set_id_callback(&OpenSSLIDCallback); l_SSLInitialized = true; } @@ -66,7 +76,7 @@ shared_ptr MakeSSLContext(const String& pubkey, const String& privkey, shared_ptr sslContext = shared_ptr(SSL_CTX_new(TLSv1_method()), SSL_CTX_free); - SSL_CTX_set_mode(sslContext.get(), SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY); + SSL_CTX_set_mode(sslContext.get(), SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); if (!SSL_CTX_use_certificate_chain_file(sslContext.get(), pubkey.CStr())) { BOOST_THROW_EXCEPTION(openssl_error()