Merge pull request #7133 from Icinga/feature/boost-asio-pki

Use new I/O engine in PkiUtility::FetchCert() and PkiUtility::RequestCertificate()
This commit is contained in:
Michael Friedrich 2019-04-23 14:27:48 +02:00 committed by GitHub
commit 0f804d126b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 235 additions and 59 deletions

View File

@ -118,6 +118,85 @@ size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& s
return msg.GetLength();
}
/**
* Reads data from a stream in netstring format.
*
* @param stream The stream to read from.
* @returns The String that has been read from the IOQueue.
* @exception invalid_argument The input stream is invalid.
* @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
*/
String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
ssize_t maxMessageLength)
{
namespace asio = boost::asio;
size_t len = 0;
bool leadingZero = false;
for (uint_fast8_t readBytes = 0;; ++readBytes) {
char byte = 0;
{
asio::mutable_buffer byteBuf (&byte, 1);
asio::read(*stream, byteBuf);
}
if (isdigit(byte)) {
if (readBytes == 9) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Length specifier must not exceed 9 characters"));
}
if (leadingZero) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (leading zero)"));
}
len = len * 10u + size_t(byte - '0');
if (!readBytes && byte == '0') {
leadingZero = true;
}
} else if (byte == ':') {
if (!readBytes) {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (no length specifier)"));
}
break;
} else {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing :)"));
}
}
if (maxMessageLength >= 0 && len > maxMessageLength) {
std::stringstream errorMessage;
errorMessage << "Max data length exceeded: " << (maxMessageLength / 1024) << " KB";
BOOST_THROW_EXCEPTION(std::invalid_argument(errorMessage.str()));
}
String payload;
if (len) {
payload.Append(len, 0);
asio::mutable_buffer payloadBuf (&*payload.Begin(), payload.GetLength());
asio::read(*stream, payloadBuf);
}
char trailer = 0;
{
asio::mutable_buffer trailerBuf (&trailer, 1);
asio::read(*stream, trailerBuf);
}
if (trailer != ',') {
BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid NetString (missing ,)"));
}
return std::move(payload);
}
/**
* Reads data from a stream in netstring format.
*
@ -197,6 +276,29 @@ String NetString::ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& str
return std::move(payload);
}
/**
* Writes data into a stream using the netstring format and returns bytes written.
*
* @param stream The stream.
* @param str The String that is to be written.
*
* @return The amount of bytes written.
*/
size_t NetString::WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& str)
{
namespace asio = boost::asio;
std::ostringstream msgbuf;
WriteStringToStream(msgbuf, str);
String msg = msgbuf.str();
asio::const_buffer msgBuf (msg.CStr(), msg.GetLength());
asio::write(*stream, msgBuf);
return msg.GetLength();
}
/**
* Writes data into a stream using the netstring format and returns bytes written.
*

View File

@ -26,9 +26,11 @@ class NetString
public:
static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context,
bool may_wait = false, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const std::shared_ptr<AsioTlsStream>& stream,
boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message);
static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message);
static size_t WriteStringToStream(const std::shared_ptr<AsioTlsStream>& stream, const String& message, boost::asio::yield_context yc);
static void WriteStringToStream(std::ostream& stream, const String& message);

View File

@ -27,6 +27,35 @@ public:
void Connect(const String& node, const String& service);
};
template<class Socket>
void Connect(Socket& socket, const String& node, const String& service)
{
using boost::asio::ip::tcp;
tcp::resolver resolver (socket.get_io_service());
tcp::resolver::query query (node, service);
auto result (resolver.resolve(query));
auto current (result.begin());
for (;;) {
try {
socket.open(current->endpoint().protocol());
socket.set_option(tcp::socket::keep_alive(true));
socket.connect(current->endpoint());
break;
} catch (const std::exception&) {
if (++current == result.end()) {
throw;
}
if (socket.is_open()) {
socket.close();
}
}
}
}
template<class Socket>
void Connect(Socket& socket, const String& node, const String& service, boost::asio::yield_context yc)
{

View File

@ -465,6 +465,11 @@ String UnbufferedAsioTlsStream::GetVerifyError() const
return m_VerifyError;
}
std::shared_ptr<X509> UnbufferedAsioTlsStream::GetPeerCertificate()
{
return std::shared_ptr<X509>(SSL_get_peer_certificate(native_handle()), X509_free);
}
void UnbufferedAsioTlsStream::BeforeHandshake(handshake_type type)
{
namespace ssl = boost::asio::ssl;

View File

@ -119,6 +119,7 @@ public:
bool IsVerifyOK() const;
String GetVerifyError() const;
std::shared_ptr<X509> GetPeerCertificate();
template<class... Args>
inline

View File

@ -524,7 +524,7 @@ void ApiListener::NewClientHandlerInternal(boost::asio::yield_context yc, const
}
});
std::shared_ptr<X509> cert (SSL_get_peer_certificate(sslConn.native_handle()), X509_free);
std::shared_ptr<X509> cert (sslConn.GetPeerCertificate());
bool verify_ok = false;
String identity;
Endpoint::Ptr endpoint;

View File

@ -59,6 +59,25 @@ size_t JsonRpc::SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& me
return NetString::WriteStringToStream(stream, json);
}
/**
* Sends a message to the connected peer and returns the bytes sent.
*
* @param message The message.
*
* @return The amount of bytes sent.
*/
size_t JsonRpc::SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message)
{
String json = JsonEncode(message);
#ifdef I2_DEBUG
if (GetDebugJsonRpcCached())
std::cerr << ConsoleColorTag(Console_ForegroundBlue) << ">> " << json << ConsoleColorTag(Console_Normal) << "\n";
#endif /* I2_DEBUG */
return NetString::WriteStringToStream(stream, json);
}
/**
* Sends a message to the connected peer and returns the bytes sent.
*
@ -106,6 +125,18 @@ StreamReadStatus JsonRpc::ReadMessage(const Stream::Ptr& stream, String *message
return StatusNewItem;
}
String JsonRpc::ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength)
{
String jsonString = NetString::ReadStringFromStream(stream, maxMessageLength);
#ifdef I2_DEBUG
if (GetDebugJsonRpcCached())
std::cerr << ConsoleColorTag(Console_ForegroundBlue) << "<< " << jsonString << ConsoleColorTag(Console_Normal) << "\n";
#endif /* I2_DEBUG */
return std::move(jsonString);
}
String JsonRpc::ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, boost::asio::yield_context yc, ssize_t maxMessageLength)
{
String jsonString = NetString::ReadStringFromStream(stream, yc, maxMessageLength);

View File

@ -22,9 +22,11 @@ class JsonRpc
{
public:
static size_t SendMessage(const Stream::Ptr& stream, const Dictionary::Ptr& message);
static size_t SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message);
static size_t SendMessage(const std::shared_ptr<AsioTlsStream>& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc);
static size_t SendRawMessage(const std::shared_ptr<AsioTlsStream>& stream, const String& json, boost::asio::yield_context yc);
static StreamReadStatus ReadMessage(const Stream::Ptr& stream, String *message, StreamReadContext& src, bool may_wait = false, ssize_t maxMessageLength = -1);
static String ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, ssize_t maxMessageLength = -1);
static String ReadMessage(const std::shared_ptr<AsioTlsStream>& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static Dictionary::Ptr DecodeMessage(const String& message);

View File

@ -34,7 +34,7 @@ Value RequestCertificateHandler(const MessageOrigin::Ptr& origin, const Dictiona
/* Use the presented client certificate if not provided. */
if (certText.IsEmpty()) {
auto stream (origin->FromClient->GetStream());
cert = std::shared_ptr<X509>(SSL_get_peer_certificate(stream->next_layer().native_handle()), X509_free);
cert = stream->next_layer().GetPeerCertificate();
} else {
cert = StringToCertificate(certText);
}

View File

@ -2,8 +2,11 @@
#include "remote/pkiutility.hpp"
#include "remote/apilistener.hpp"
#include "base/defer.hpp"
#include "base/io-engine.hpp"
#include "base/logger.hpp"
#include "base/application.hpp"
#include "base/tcpsocket.hpp"
#include "base/tlsutility.hpp"
#include "base/console.hpp"
#include "base/tlsstream.hpp"
@ -14,6 +17,7 @@
#include "remote/jsonrpc.hpp"
#include <fstream>
#include <iostream>
#include <boost/asio/ssl/context.hpp>
using namespace icinga;
@ -76,22 +80,10 @@ int PkiUtility::SignCsr(const String& csrfile, const String& certfile)
std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& port)
{
TcpSocket::Ptr client = new TcpSocket();
std::shared_ptr<boost::asio::ssl::context> sslContext;
try {
client->Connect(host, port);
} catch (const std::exception& ex) {
Log(LogCritical, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return std::shared_ptr<X509>();
}
std::shared_ptr<SSL_CTX> sslContext;
try {
sslContext = MakeSSLContext();
sslContext = MakeAsioSslContext();
} catch (const std::exception& ex) {
Log(LogCritical, "pki")
<< "Cannot make SSL context.";
@ -100,17 +92,31 @@ std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& po
return std::shared_ptr<X509>();
}
TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext);
auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
try {
stream->Handshake();
Connect(stream->lowest_layer(), host, port);
} catch (const std::exception& ex) {
Log(LogCritical, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "pki")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return std::shared_ptr<X509>();
}
auto& sslConn (stream->next_layer());
try {
sslConn.handshake(sslConn.client);
} catch (const std::exception& ex) {
Log(LogCritical, "pki")
<< "Client TLS handshake failed. (" << ex.what() << ")";
return std::shared_ptr<X509>();
}
return stream->GetPeerCertificate();
Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
return sslConn.GetPeerCertificate();
}
int PkiUtility::WriteCert(const std::shared_ptr<X509>& cert, const String& trustedfile)
@ -142,22 +148,10 @@ int PkiUtility::GenTicket(const String& cn, const String& salt, std::ostream& ti
int PkiUtility::RequestCertificate(const String& host, const String& port, const String& keyfile,
const String& certfile, const String& cafile, const std::shared_ptr<X509>& trustedCert, const String& ticket)
{
TcpSocket::Ptr client = new TcpSocket();
std::shared_ptr<boost::asio::ssl::context> sslContext;
try {
client->Connect(host, port);
} catch (const std::exception& ex) {
Log(LogCritical, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return 1;
}
std::shared_ptr<SSL_CTX> sslContext;
try {
sslContext = MakeSSLContext(certfile, keyfile);
sslContext = MakeAsioSslContext(certfile, keyfile);
} catch (const std::exception& ex) {
Log(LogCritical, "cli")
<< "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'.";
@ -166,17 +160,31 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
return 1;
}
TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext);
auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
try {
stream->Handshake();
Connect(stream->lowest_layer(), host, port);
} catch (const std::exception& ex) {
Log(LogCritical, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "'";
Log(LogDebug, "cli")
<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
return 1;
}
auto& sslConn (stream->next_layer());
try {
sslConn.handshake(sslConn.client);
} catch (const std::exception& ex) {
Log(LogCritical, "cli")
<< "Client TLS handshake failed: " << DiagnosticInformation(ex, false);
return 1;
}
std::shared_ptr<X509> peerCert = stream->GetPeerCertificate();
Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
auto peerCert (sslConn.GetPeerCertificate());
if (X509_cmp(peerCert.get(), trustedCert.get())) {
Log(LogCritical, "cli", "Peer certificate does not match trusted certificate.");
@ -196,36 +204,32 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
{ "params", params }
});
JsonRpc::SendMessage(stream, request);
String jsonString;
Dictionary::Ptr response;
StreamReadContext src;
for (;;) {
StreamReadStatus srs = JsonRpc::ReadMessage(stream, &jsonString, src);
try {
JsonRpc::SendMessage(stream, request);
stream->flush();
if (srs == StatusEof)
break;
for (;;) {
response = JsonRpc::DecodeMessage(JsonRpc::ReadMessage(stream));
if (srs != StatusNewItem)
continue;
response = JsonRpc::DecodeMessage(jsonString);
if (response && response->Contains("error")) {
Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
if (response && response->Contains("error")) {
Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
#ifdef I2_DEBUG
/* we shouldn't expose master errors to the user in production environments */
Log(LogCritical, "cli", response->Get("error"));
/* we shouldn't expose master errors to the user in production environments */
Log(LogCritical, "cli", response->Get("error"));
#endif /* I2_DEBUG */
return 1;
return 1;
}
if (response && (response->Get("id") != msgid))
continue;
break;
}
if (response && (response->Get("id") != msgid))
continue;
break;
} catch (...) {
Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log.");
return 1;
}
if (!response) {