Merge pull request #7133 from Icinga/feature/boost-asio-pki

Use new I/O engine in PkiUtility::FetchCert() and PkiUtility::RequestCertificate()
This commit is contained in:
Michael Friedrich 2019-04-23 14:27:48 +02:00 committed by GitHub
commit 0f804d126b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 235 additions and 59 deletions

View File

@ -118,6 +118,85 @@ size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& s
return msg.GetLength(); return msg.GetLength();
} }
/**
* Reads data from a stream in netstring format.
*
* @param stream The stream to read from.
* @returns The String that has been read from the IOQueue.
* @exception invalid_argument The input stream is invalid.
* @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
*/
String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
ssize_t maxMessageLength)
{
namespace asio = boost::asio;
size_t len = 0;
bool leadingZero = false;
for (uint_fast8_t readBytes = 0;; ++readBytes) {
char byte = 0;
{
asio::mutable_buffer byteBuf (&byte, 1);
asio::read(*stream, byteBuf);
}
if (isdigit(byte)) {
if (readBytes == 9) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
}
if (leadingZero) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
}
len = len * 10u + size_t(byte - '0');
if (!readBytes && byte == '0') {
leadingZero = true;
}
} else if (byte == ':') {
if (!readBytes) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
}
break;
} else {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
}
}
if (maxMessageLength >= 0 && len > maxMessageLength) {
std::stringstream errorMessage;
errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
}
String payload;
if (len) {
payload.Append(len, 0);
asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
asio::read(*stream, payloadBuf);
}
char trailer = 0;
{
asio::mutable_buffer trailerBuf (&trailer, 1);
asio::read(*stream, trailerBuf);
}
if (trailer != ',') {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
}
return std::move(payload);
}
/** /**
* Reads data from a stream in netstring format. * Reads data from a stream in netstring format.
* *
@ -197,6 +276,29 @@ String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& str
return std::move(payload); return std::move(payload);
} }
/**
* Writes data into a stream using the netstring format and returns bytes written.
*
* @param stream The stream.
* @param str The String that is to be written.
*
* @return The amount of bytes written.
*/
size_t NetString::WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& str)
{
namespace asio = boost::asio;
std::ostringstream msgbuf;
WriteStringToStream(msgbuf, str);
String msg = msgbuf.str();
asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());
asio::write(*stream, msgBuf);
return msg.GetLength();
}
/** /**
* Writes data into a stream using the netstring format and returns bytes written. * Writes data into a stream using the netstring format and returns bytes written.
* *

View File

@ -26,9 +26,11 @@ class NetString
public: public:
static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context, static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context,
bool may_wait = false, ssize_t maxMessageLength = -1); bool may_wait = false, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream, static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
boost::asio::yield_context yc, ssize_t maxMessageLength = -1); boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message); static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message);
static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message);
static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message, boost::asio::yield_context yc); static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message, boost::asio::yield_context yc);
static void WriteStringToStream(std::ostream& stream, const String& message); static void WriteStringToStream(std::ostream& stream, const String& message);

View File

@ -27,6 +27,35 @@ public:
void Connect(const String& node, const String& service); void Connect(const String& node, const String& service);
}; };
template<class Socket>
void Connect(Socket& socket, const String& node, const String& service)
{
using boost::asio::ip::tcp;
tcp::resolver resolver (socket.get_io_service());
tcp::resolver::query query (node, service);
auto result (resolver.resolve(query));
auto current (result.begin());
for (;;) {
try {
socket.open(current->endpoint().protocol());
socket.set_option(tcp::socket::keep_alive(true));
socket.connect(current->endpoint());
break;
} catch (const std::exception&) {
if (++current == result.end()) {
throw;
}
if (socket.is_open()) {
socket.close();
}
}
}
}
template<class Socket> template<class Socket>
void Connect(Socket& socket, const String& node, const String& service, boost::asio::yield_context yc) void Connect(Socket& socket, const String& node, const String& service, boost::asio::yield_context yc)
{ {

View File

@ -465,6 +465,11 @@ String UnbufferedAsioTlsStream::GetVerifyError() const
return m_VerifyError; return m_VerifyError;
} }
std::shared_ptr<X509> UnbufferedAsioTlsStream::GetPeerCertificate()
{
return std::shared_ptr<X509>(SSL_get_peer_certificate(native_handle()), X509_free);
}
void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type) void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type)
{ {
namespace ssl = boost::asio::ssl; namespace ssl = boost::asio::ssl;

View File

@ -119,6 +119,7 @@ public:
bool IsVerifyOK() const; bool IsVerifyOK() const;
String GetVerifyError() const; String GetVerifyError() const;
std::shared_ptr<X509> GetPeerCertificate();
template<class... Args> template<class... Args>
inline inline

View File

@ -524,7 +524,7 @@ void ApiListener::NewClientHandlerInternal(boost::asio::yield_context yc, const
} }
}); });
std::shared_ptr<X509> cert (SSL_get_peer_certificate(sslConn.native_handle()), X509_free); std::shared_ptr<X509> cert (sslConn.GetPeerCertificate());
bool verify_ok = false; bool verify_ok = false;
String identity; String identity;
Endpoint::Ptr endpoint; Endpoint::Ptr endpoint;

View File

@ -59,6 +59,25 @@ size_t JsonRpc::SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& me
return NetString::WriteStringToStream(stream, json); return NetString::WriteStringToStream(stream, json);
} }
/**
* Sends a message to the connected peer and returns the bytes sent.
*
* @param message The message.
*
* @return The amount of bytes sent.
*/
size_t JsonRpc::SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message)
{
String json = JsonEncode(message);
#ifdef I2_DEBUG
if (GetDebugJsonRpcCached())
std::cerr << ConsoleColorTag(Console_ForegroundBlue) << ">> " << json << ConsoleColorTag(Console_Normal) << "\n";
#endif /* I2_DEBUG */
return NetString::WriteStringToStream(stream, json);
}
/** /**
* Sends a message to the connected peer and returns the bytes sent. * Sends a message to the connected peer and returns the bytes sent.
* *
@ -106,6 +125,18 @@ StreamReadStatus JsonRpc::ReadMessage(const Stream::Ptr& stream, String *message
return StatusNewItem; return StatusNewItem;
} }
String JsonRpc::ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength)
{
String jsonString = NetString::ReadStringFromStream(stream, maxMessageLength);
#ifdef I2_DEBUG
if (GetDebugJsonRpcCached())
std::cerr << ConsoleColorTag(Console_ForegroundBlue) << "<< " << jsonString << ConsoleColorTag(Console_Normal) << "\n";
#endif /* I2_DEBUG */
return std::move(jsonString);
}
String JsonRpc::ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, boost::asio::yield_context yc, ssize_t maxMessageLength) String JsonRpc::ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, boost::asio::yield_context yc, ssize_t maxMessageLength)
{ {
String jsonString = NetString::ReadStringFromStream(stream, yc, maxMessageLength); String jsonString = NetString::ReadStringFromStream(stream, yc, maxMessageLength);

View File

@ -22,9 +22,11 @@ class JsonRpc
{ {
public: public:
static size_t SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& message); static size_t SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& message);
static size_t SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message);
static size_t SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc); static size_t SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc);
static size_t SendRawMessage(const std::shared_ptr<AsioTlsStream>& stream, const String& json, boost::asio::yield_context yc); static size_t SendRawMessage(const std::shared_ptr<AsioTlsStream>& stream, const String& json, boost::asio::yield_context yc);
static StreamReadStatus ReadMessage(const Stream::Ptr& stream, String *message, StreamReadContext& src, bool may_wait = false, ssize_t maxMessageLength = -1); static StreamReadStatus ReadMessage(const Stream::Ptr& stream, String *message, StreamReadContext& src, bool may_wait = false, ssize_t maxMessageLength = -1);
static String ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength = -1);
static String ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1); static String ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static Dictionary::Ptr DecodeMessage(const String& message); static Dictionary::Ptr DecodeMessage(const String& message);

View File

@ -34,7 +34,7 @@ Value RequestCertificateHandler(const MessageOrigin::Ptr& origin, const Dictiona
/* Use the presented client certificate if not provided. */ /* Use the presented client certificate if not provided. */
if (certText.IsEmpty()) { if (certText.IsEmpty()) {
auto stream (origin->FromClient->GetStream()); auto stream (origin->FromClient->GetStream());
cert = std::shared_ptr<X509>(SSL_get_peer_certificate(stream->next_layer().native_handle()), X509_free); cert = stream->next_layer().GetPeerCertificate();
} else { } else {
cert = StringToCertificate(certText); cert = StringToCertificate(certText);
} }

View File

@ -2,8 +2,11 @@
#include "remote/pkiutility.hpp" #include "remote/pkiutility.hpp"
#include "remote/apilistener.hpp" #include "remote/apilistener.hpp"
#include "base/defer.hpp"
#include "base/io-engine.hpp"
#include "base/logger.hpp" #include "base/logger.hpp"
#include "base/application.hpp" #include "base/application.hpp"
#include "base/tcpsocket.hpp"
#include "base/tlsutility.hpp" #include "base/tlsutility.hpp"
#include "base/console.hpp" #include "base/console.hpp"
#include "base/tlsstream.hpp" #include "base/tlsstream.hpp"
@ -14,6 +17,7 @@
#include "remote/jsonrpc.hpp" #include "remote/jsonrpc.hpp"
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <boost/asio/ssl/context.hpp>
using namespace icinga; using namespace icinga;
@ -76,22 +80,10 @@ int PkiUtility::SignCsr(const String& csrfile, const String& certfile)
std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& port) std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& port)
{ {
TcpSocket::Ptr client = new TcpSocket(); std::shared_ptr<boost::asio::ssl::context> sslContext;
try { try {
client->Connect(host, port); sslContext = MakeAsioSslContext();
} catch (const std::exception& ex) {
Log(LogCritical, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return std::shared_ptr<X509>();
}
std::shared_ptr<SSL_CTX> sslContext;
try {
sslContext = MakeSSLContext();
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
Log(LogCritical, "pki") Log(LogCritical, "pki")
<< "Cannot make SSL context."; << "Cannot make SSL context.";
@ -100,17 +92,31 @@ std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& po
return std::shared_ptr<X509>(); return std::shared_ptr<X509>();
} }
TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext); auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
try { try {
stream->Handshake(); Connect(stream->lowest_layer(), host, port);
} catch (const std::exception& ex) {
Log(LogCritical, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return std::shared_ptr<X509>();
}
auto& sslConn (stream->next_layer());
try {
sslConn.handshake(sslConn.client);
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
Log(LogCritical, "pki") Log(LogCritical, "pki")
<< "Client TLS handshake failed. (" << ex.what() << ")"; << "Client TLS handshake failed. (" << ex.what() << ")";
return std::shared_ptr<X509>(); return std::shared_ptr<X509>();
} }
return stream->GetPeerCertificate(); Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
return sslConn.GetPeerCertificate();
} }
int PkiUtility::WriteCert(const std::shared_ptr<X509>& cert, const String& trustedfile) int PkiUtility::WriteCert(const std::shared_ptr<X509>& cert, const String& trustedfile)
@ -142,22 +148,10 @@ int PkiUtility::GenTicket(const String& cn, const String& salt, std::ostream& ti
int PkiUtility::RequestCertificate(const String& host, const String& port, const String& keyfile, int PkiUtility::RequestCertificate(const String& host, const String& port, const String& keyfile,
const String& certfile, const String& cafile, const std::shared_ptr<X509>& trustedCert, const String& ticket) const String& certfile, const String& cafile, const std::shared_ptr<X509>& trustedCert, const String& ticket)
{ {
TcpSocket::Ptr client = new TcpSocket(); std::shared_ptr<boost::asio::ssl::context> sslContext;
try { try {
client->Connect(host, port); sslContext = MakeAsioSslContext(certfile, keyfile);
} catch (const std::exception& ex) {
Log(LogCritical, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return 1;
}
std::shared_ptr<SSL_CTX> sslContext;
try {
sslContext = MakeSSLContext(certfile, keyfile);
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
Log(LogCritical, "cli") Log(LogCritical, "cli")
<< "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'."; << "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'.";
@ -166,17 +160,31 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
return 1; return 1;
} }
TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext); auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
try { try {
stream->Handshake(); Connect(stream->lowest_layer(), host, port);
} catch (const std::exception& ex) {
Log(LogCritical, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return 1;
}
auto& sslConn (stream->next_layer());
try {
sslConn.handshake(sslConn.client);
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
Log(LogCritical, "cli") Log(LogCritical, "cli")
<< "Client TLS handshake failed: " << DiagnosticInformation(ex, false); << "Client TLS handshake failed: " << DiagnosticInformation(ex, false);
return 1; return 1;
} }
std::shared_ptr<X509> peerCert = stream->GetPeerCertificate(); Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
auto peerCert (sslConn.GetPeerCertificate());
if (X509_cmp(peerCert.get(), trustedCert.get())) { if (X509_cmp(peerCert.get(), trustedCert.get())) {
Log(LogCritical, "cli", "Peer certificate does not match trusted certificate."); Log(LogCritical, "cli", "Peer certificate does not match trusted certificate.");
@ -196,36 +204,32 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
{ "params", params } { "params", params }
}); });
JsonRpc::SendMessage(stream, request);
String jsonString;
Dictionary::Ptr response; Dictionary::Ptr response;
StreamReadContext src;
for (;;) { try {
StreamReadStatus srs = JsonRpc::ReadMessage(stream, &jsonString, src); JsonRpc::SendMessage(stream, request);
stream->flush();
if (srs == StatusEof) for (;;) {
break; response = JsonRpc::DecodeMessage(JsonRpc::ReadMessage(stream));
if (srs != StatusNewItem) if (response && response->Contains("error")) {
continue; Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
response = JsonRpc::DecodeMessage(jsonString);
if (response && response->Contains("error")) {
Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
#ifdef I2_DEBUG #ifdef I2_DEBUG
/* we shouldn't expose master errors to the user in production environments */ /* we shouldn't expose master errors to the user in production environments */
Log(LogCritical, "cli", response->Get("error")); Log(LogCritical, "cli", response->Get("error"));
#endif /* I2_DEBUG */ #endif /* I2_DEBUG */
return 1; return 1;
}
if (response && (response->Get("id") != msgid))
continue;
break;
} }
} catch (...) {
if (response && (response->Get("id") != msgid)) Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log.");
continue; return 1;
break;
} }
if (!response) { if (!response) {