From d1e87bdc4506b4bc10e2740ff3465c04b7ed0a49 Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Mon, 25 Feb 2019 16:40:14 +0100 Subject: [PATCH 1/5] Connect(): add non-async overload --- lib/base/tcpsocket.hpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/lib/base/tcpsocket.hpp b/lib/base/tcpsocket.hpp index 069288aad..0f8334f01 100644 --- a/lib/base/tcpsocket.hpp +++ b/lib/base/tcpsocket.hpp @@ -27,6 +27,35 @@ public: void Connect(const String& node, const String& service); }; +template +void Connect(Socket& socket, const String& node, const String& service) +{ + using boost::asio::ip::tcp; + + tcp::resolver resolver (socket.get_io_service()); + tcp::resolver::query query (node, service); + auto result (resolver.resolve(query)); + auto current (result.begin()); + + for (;;) { + try { + socket.open(current->endpoint().protocol()); + socket.set_option(tcp::socket::keep_alive(true)); + socket.connect(current->endpoint()); + + break; + } catch (const std::exception&) { + if (++current == result.end()) { + throw; + } + + if (socket.is_open()) { + socket.close(); + } + } + } +} + template void Connect(Socket& socket, const String& node, const String& service, boost::asio::yield_context yc) { From f2d9d91e8389e9a4fb2ddf2248c9ce603be8cbfd Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Mon, 25 Feb 2019 17:22:00 +0100 Subject: [PATCH 2/5] Introduce UnbufferedAsioTlsStream#GetPeerCertificate() --- lib/base/tlsstream.cpp | 5 +++++ lib/base/tlsstream.hpp | 1 + lib/remote/apilistener.cpp | 2 +- lib/remote/jsonrpcconnection-pki.cpp | 2 +- 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/base/tlsstream.cpp b/lib/base/tlsstream.cpp index 38913f28a..5f6fe33cf 100644 --- a/lib/base/tlsstream.cpp +++ b/lib/base/tlsstream.cpp @@ -465,6 +465,11 @@ String UnbufferedAsioTlsStream::GetVerifyError() const return m_VerifyError; } +std::shared_ptr UnbufferedAsioTlsStream::GetPeerCertificate() +{ + return std::shared_ptr(SSL_get_peer_certificate(native_handle()), X509_free); +} + void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type) { namespace ssl = boost::asio::ssl; diff --git a/lib/base/tlsstream.hpp b/lib/base/tlsstream.hpp index 6156a3d2f..3974d12b8 100644 --- a/lib/base/tlsstream.hpp +++ b/lib/base/tlsstream.hpp @@ -119,6 +119,7 @@ public: bool IsVerifyOK() const; String GetVerifyError() const; + std::shared_ptr GetPeerCertificate(); template inline diff --git a/lib/remote/apilistener.cpp b/lib/remote/apilistener.cpp index de6e754c0..c534c0969 100644 --- a/lib/remote/apilistener.cpp +++ b/lib/remote/apilistener.cpp @@ -523,7 +523,7 @@ void ApiListener::NewClientHandlerInternal(boost::asio::yield_context yc, const } }); - std::shared_ptr cert (SSL_get_peer_certificate(sslConn.native_handle()), X509_free); + std::shared_ptr cert (sslConn.GetPeerCertificate()); bool verify_ok = false; String identity; Endpoint::Ptr endpoint; diff --git a/lib/remote/jsonrpcconnection-pki.cpp b/lib/remote/jsonrpcconnection-pki.cpp index 66f88479b..2f66eb7b5 100644 --- a/lib/remote/jsonrpcconnection-pki.cpp +++ b/lib/remote/jsonrpcconnection-pki.cpp @@ -34,7 +34,7 @@ Value RequestCertificateHandler(const MessageOrigin::Ptr& origin, const Dictiona /* Use the presented client certificate if not provided. */ if (certText.IsEmpty()) { auto stream (origin->FromClient->GetStream()); - cert = std::shared_ptr(SSL_get_peer_certificate(stream->next_layer().native_handle()), X509_free); + cert = stream->next_layer().GetPeerCertificate(); } else { cert = StringToCertificate(certText); } From f4a78380e91f47ebe0a479eae9655e4fc269b85a Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Mon, 25 Feb 2019 18:12:32 +0100 Subject: [PATCH 3/5] Add non-async overloads for NetString::ReadStringFromStream() and NetString::WriteStringToStream() --- lib/base/netstring.cpp | 102 +++++++++++++++++++++++++++++++++++++++++ lib/base/netstring.hpp | 2 + 2 files changed, 104 insertions(+) diff --git a/lib/base/netstring.cpp b/lib/base/netstring.cpp index 489a8b40d..2be7675a7 100644 --- a/lib/base/netstring.cpp +++ b/lib/base/netstring.cpp @@ -118,6 +118,85 @@ size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& s return msg.GetLength(); } +/** + * Reads data from a stream in netstring format. + * + * @param stream The stream to read from. + * @returns The String that has been read from the IOQueue. + * @exception invalid_argument The input stream is invalid. + * @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c + */ +String NetString::ReadStringFromStream(const std::shared_ptr& stream, + ssize_t maxMessageLength) +{ + namespace asio = boost::asio; + + size_t len = 0; + bool leadingZero = false; + + for (uint_fast8_t readBytes = 0;; ++readBytes) { + char byte = 0; + + { + asio::mutable_buffer byteBuf (&byte, 1); + asio::read(*stream, byteBuf); + } + + if (isdigit(byte)) { + if (readBytes == 9) { + BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters")); + } + + if (leadingZero) { + BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)")); + } + + len = len * 10u + size_t(byte - '0'); + + if (!readBytes && byte == '0') { + leadingZero = true; + } + } else if (byte == ':') { + if (!readBytes) { + BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)")); + } + + break; + } else { + BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)")); + } + } + + if (maxMessageLength >= 0 && len > maxMessageLength) { + std::stringstream errorMessage; + errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB"; + + BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str())); + } + + String payload; + + if (len) { + payload.Append(len, 0); + + asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength()); + asio::read(*stream, payloadBuf); + } + + char trailer = 0; + + { + asio::mutable_buffer trailerBuf (&trailer, 1); + asio::read(*stream, trailerBuf); + } + + if (trailer != ',') { + BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)")); + } + + return std::move(payload); +} + /** * Reads data from a stream in netstring format. * @@ -197,6 +276,29 @@ String NetString::ReadStringFromStream(const std::shared_ptr& str return std::move(payload); } +/** + * Writes data into a stream using the netstring format and returns bytes written. + * + * @param stream The stream. + * @param str The String that is to be written. + * + * @return The amount of bytes written. + */ +size_t NetString::WriteStringToStream(const std::shared_ptr& stream, const String& str) +{ + namespace asio = boost::asio; + + std::ostringstream msgbuf; + WriteStringToStream(msgbuf, str); + + String msg = msgbuf.str(); + asio::const_buffer msgBuf (msg.CStr(), msg.GetLength()); + + asio::write(*stream, msgBuf); + + return msg.GetLength(); +} + /** * Writes data into a stream using the netstring format and returns bytes written. * diff --git a/lib/base/netstring.hpp b/lib/base/netstring.hpp index f84eac7a3..2d2435907 100644 --- a/lib/base/netstring.hpp +++ b/lib/base/netstring.hpp @@ -26,9 +26,11 @@ class NetString public: static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context, bool may_wait = false, ssize_t maxMessageLength = -1); + static String ReadStringFromStream(const std::shared_ptr& stream, ssize_t maxMessageLength = -1); static String ReadStringFromStream(const std::shared_ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1); static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message); + static size_t WriteStringToStream(const std::shared_ptr& stream, const String& message); static size_t WriteStringToStream(const std::shared_ptr& stream, const String& message, boost::asio::yield_context yc); static void WriteStringToStream(std::ostream& stream, const String& message); From 6e7932f157e83db4831db617a207052084dc9cda Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Mon, 25 Feb 2019 18:15:47 +0100 Subject: [PATCH 4/5] Add non-async overloads for JsonRpc::ReadMessage() and JsonRpc::SendMessage() --- lib/remote/jsonrpc.cpp | 31 +++++++++++++++++++++++++++++++ lib/remote/jsonrpc.hpp | 2 ++ 2 files changed, 33 insertions(+) diff --git a/lib/remote/jsonrpc.cpp b/lib/remote/jsonrpc.cpp index 03f3c7d0e..63bc5ff85 100644 --- a/lib/remote/jsonrpc.cpp +++ b/lib/remote/jsonrpc.cpp @@ -59,6 +59,25 @@ size_t JsonRpc::SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& me return NetString::WriteStringToStream(stream, json); } +/** + * Sends a message to the connected peer and returns the bytes sent. + * + * @param message The message. + * + * @return The amount of bytes sent. + */ +size_t JsonRpc::SendMessage(const std::shared_ptr& stream, const Dictionary::Ptr& message) +{ + String json = JsonEncode(message); + +#ifdef I2_DEBUG + if (GetDebugJsonRpcCached()) + std::cerr << ConsoleColorTag(Console_ForegroundBlue) << ">> " << json << ConsoleColorTag(Console_Normal) << "\n"; +#endif /* I2_DEBUG */ + + return NetString::WriteStringToStream(stream, json); +} + /** * Sends a message to the connected peer and returns the bytes sent. * @@ -106,6 +125,18 @@ StreamReadStatus JsonRpc::ReadMessage(const Stream::Ptr& stream, String *message return StatusNewItem; } +String JsonRpc::ReadMessage(const std::shared_ptr& stream, ssize_t maxMessageLength) +{ + String jsonString = NetString::ReadStringFromStream(stream, maxMessageLength); + +#ifdef I2_DEBUG + if (GetDebugJsonRpcCached()) + std::cerr << ConsoleColorTag(Console_ForegroundBlue) << "<< " << jsonString << ConsoleColorTag(Console_Normal) << "\n"; +#endif /* I2_DEBUG */ + + return std::move(jsonString); +} + String JsonRpc::ReadMessage(const std::shared_ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength) { String jsonString = NetString::ReadStringFromStream(stream, yc, maxMessageLength); diff --git a/lib/remote/jsonrpc.hpp b/lib/remote/jsonrpc.hpp index faf9c07e8..98187fe6c 100644 --- a/lib/remote/jsonrpc.hpp +++ b/lib/remote/jsonrpc.hpp @@ -22,9 +22,11 @@ class JsonRpc { public: static size_t SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& message); + static size_t SendMessage(const std::shared_ptr& stream, const Dictionary::Ptr& message); static size_t SendMessage(const std::shared_ptr& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc); static size_t SendRawMessage(const std::shared_ptr& stream, const String& json, boost::asio::yield_context yc); static StreamReadStatus ReadMessage(const Stream::Ptr& stream, String *message, StreamReadContext& src, bool may_wait = false, ssize_t maxMessageLength = -1); + static String ReadMessage(const std::shared_ptr& stream, ssize_t maxMessageLength = -1); static String ReadMessage(const std::shared_ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1); static Dictionary::Ptr DecodeMessage(const String& message); From 00d859234e64f76c0e3bc4da20be71e0d7e3bbe5 Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Mon, 25 Feb 2019 18:58:04 +0100 Subject: [PATCH 5/5] Use new I/O engine in PkiUtility::FetchCert() and PkiUtility::RequestCertificate() --- lib/remote/pkiutility.cpp | 118 ++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 57 deletions(-) diff --git a/lib/remote/pkiutility.cpp b/lib/remote/pkiutility.cpp index e1e785288..c08989dd8 100644 --- a/lib/remote/pkiutility.cpp +++ b/lib/remote/pkiutility.cpp @@ -2,8 +2,11 @@ #include "remote/pkiutility.hpp" #include "remote/apilistener.hpp" +#include "base/defer.hpp" +#include "base/io-engine.hpp" #include "base/logger.hpp" #include "base/application.hpp" +#include "base/tcpsocket.hpp" #include "base/tlsutility.hpp" #include "base/console.hpp" #include "base/tlsstream.hpp" @@ -14,6 +17,7 @@ #include "remote/jsonrpc.hpp" #include #include +#include using namespace icinga; @@ -76,22 +80,10 @@ int PkiUtility::SignCsr(const String& csrfile, const String& certfile) std::shared_ptr PkiUtility::FetchCert(const String& host, const String& port) { - TcpSocket::Ptr client = new TcpSocket(); + std::shared_ptr sslContext; try { - client->Connect(host, port); - } catch (const std::exception& ex) { - Log(LogCritical, "pki") - << "Cannot connect to host '" << host << "' on port '" << port << "'"; - Log(LogDebug, "pki") - << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex); - return std::shared_ptr(); - } - - std::shared_ptr sslContext; - - try { - sslContext = MakeSSLContext(); + sslContext = MakeAsioSslContext(); } catch (const std::exception& ex) { Log(LogCritical, "pki") << "Cannot make SSL context."; @@ -100,17 +92,31 @@ std::shared_ptr PkiUtility::FetchCert(const String& host, const String& po return std::shared_ptr(); } - TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext); + auto stream (std::make_shared(IoEngine::Get().GetIoService(), *sslContext, host)); try { - stream->Handshake(); + Connect(stream->lowest_layer(), host, port); + } catch (const std::exception& ex) { + Log(LogCritical, "pki") + << "Cannot connect to host '" << host << "' on port '" << port << "'"; + Log(LogDebug, "pki") + << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex); + return std::shared_ptr(); + } + + auto& sslConn (stream->next_layer()); + + try { + sslConn.handshake(sslConn.client); } catch (const std::exception& ex) { Log(LogCritical, "pki") << "Client TLS handshake failed. (" << ex.what() << ")"; return std::shared_ptr(); } - return stream->GetPeerCertificate(); + Defer shutdown ([&sslConn]() { sslConn.shutdown(); }); + + return sslConn.GetPeerCertificate(); } int PkiUtility::WriteCert(const std::shared_ptr& cert, const String& trustedfile) @@ -142,22 +148,10 @@ int PkiUtility::GenTicket(const String& cn, const String& salt, std::ostream& ti int PkiUtility::RequestCertificate(const String& host, const String& port, const String& keyfile, const String& certfile, const String& cafile, const std::shared_ptr& trustedCert, const String& ticket) { - TcpSocket::Ptr client = new TcpSocket(); + std::shared_ptr sslContext; try { - client->Connect(host, port); - } catch (const std::exception& ex) { - Log(LogCritical, "cli") - << "Cannot connect to host '" << host << "' on port '" << port << "'"; - Log(LogDebug, "cli") - << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex); - return 1; - } - - std::shared_ptr sslContext; - - try { - sslContext = MakeSSLContext(certfile, keyfile); + sslContext = MakeAsioSslContext(certfile, keyfile); } catch (const std::exception& ex) { Log(LogCritical, "cli") << "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'."; @@ -166,17 +160,31 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const return 1; } - TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext); + auto stream (std::make_shared(IoEngine::Get().GetIoService(), *sslContext, host)); try { - stream->Handshake(); + Connect(stream->lowest_layer(), host, port); + } catch (const std::exception& ex) { + Log(LogCritical, "cli") + << "Cannot connect to host '" << host << "' on port '" << port << "'"; + Log(LogDebug, "cli") + << "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex); + return 1; + } + + auto& sslConn (stream->next_layer()); + + try { + sslConn.handshake(sslConn.client); } catch (const std::exception& ex) { Log(LogCritical, "cli") << "Client TLS handshake failed: " << DiagnosticInformation(ex, false); return 1; } - std::shared_ptr peerCert = stream->GetPeerCertificate(); + Defer shutdown ([&sslConn]() { sslConn.shutdown(); }); + + auto peerCert (sslConn.GetPeerCertificate()); if (X509_cmp(peerCert.get(), trustedCert.get())) { Log(LogCritical, "cli", "Peer certificate does not match trusted certificate."); @@ -196,36 +204,32 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const { "params", params } }); - JsonRpc::SendMessage(stream, request); - - String jsonString; Dictionary::Ptr response; - StreamReadContext src; - for (;;) { - StreamReadStatus srs = JsonRpc::ReadMessage(stream, &jsonString, src); + try { + JsonRpc::SendMessage(stream, request); + stream->flush(); - if (srs == StatusEof) - break; + for (;;) { + response = JsonRpc::DecodeMessage(JsonRpc::ReadMessage(stream)); - if (srs != StatusNewItem) - continue; - - response = JsonRpc::DecodeMessage(jsonString); - - if (response && response->Contains("error")) { - Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug)."); + if (response && response->Contains("error")) { + Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug)."); #ifdef I2_DEBUG - /* we shouldn't expose master errors to the user in production environments */ - Log(LogCritical, "cli", response->Get("error")); + /* we shouldn't expose master errors to the user in production environments */ + Log(LogCritical, "cli", response->Get("error")); #endif /* I2_DEBUG */ - return 1; + return 1; + } + + if (response && (response->Get("id") != msgid)) + continue; + + break; } - - if (response && (response->Get("id") != msgid)) - continue; - - break; + } catch (...) { + Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log."); + return 1; } if (!response) {