From 00d859234e64f76c0e3bc4da20be71e0d7e3bbe5 Mon Sep 17 00:00:00 2001
From: "Alexander A. Klimov" <alexander.klimov@icinga.com>
Date: Mon, 25 Feb 2019 18:58:04 +0100
Subject: [PATCH] Use new I/O engine in PkiUtility::FetchCert() and
 PkiUtility::RequestCertificate()

---
 lib/remote/pkiutility.cpp | 118 ++++++++++++++++++++------------------
 1 file changed, 61 insertions(+), 57 deletions(-)

diff --git a/lib/remote/pkiutility.cpp b/lib/remote/pkiutility.cpp
index e1e785288..c08989dd8 100644
--- a/lib/remote/pkiutility.cpp
+++ b/lib/remote/pkiutility.cpp
@@ -2,8 +2,11 @@
 
 #include "remote/pkiutility.hpp"
 #include "remote/apilistener.hpp"
+#include "base/defer.hpp"
+#include "base/io-engine.hpp"
 #include "base/logger.hpp"
 #include "base/application.hpp"
+#include "base/tcpsocket.hpp"
 #include "base/tlsutility.hpp"
 #include "base/console.hpp"
 #include "base/tlsstream.hpp"
@@ -14,6 +17,7 @@
 #include "remote/jsonrpc.hpp"
 #include <fstream>
 #include <iostream>
+#include <boost/asio/ssl/context.hpp>
 
 using namespace icinga;
 
@@ -76,22 +80,10 @@ int PkiUtility::SignCsr(const String& csrfile, const String& certfile)
 
 std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& port)
 {
-	TcpSocket::Ptr client = new TcpSocket();
+	std::shared_ptr<boost::asio::ssl::context> sslContext;
 
 	try {
-		client->Connect(host, port);
-	} catch (const std::exception& ex) {
-		Log(LogCritical, "pki")
-			<< "Cannot connect to host '" << host << "' on port '" << port << "'";
-		Log(LogDebug, "pki")
-			<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
-		return std::shared_ptr<X509>();
-	}
-
-	std::shared_ptr<SSL_CTX> sslContext;
-
-	try {
-		sslContext = MakeSSLContext();
+		sslContext = MakeAsioSslContext();
 	} catch (const std::exception& ex) {
 		Log(LogCritical, "pki")
 			<< "Cannot make SSL context.";
@@ -100,17 +92,31 @@ std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& po
 		return std::shared_ptr<X509>();
 	}
 
-	TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext);
+	auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
 
 	try {
-		stream->Handshake();
+		Connect(stream->lowest_layer(), host, port);
+	} catch (const std::exception& ex) {
+		Log(LogCritical, "pki")
+			<< "Cannot connect to host '" << host << "' on port '" << port << "'";
+		Log(LogDebug, "pki")
+			<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
+		return std::shared_ptr<X509>();
+	}
+
+	auto& sslConn (stream->next_layer());
+
+	try {
+		sslConn.handshake(sslConn.client);
 	} catch (const std::exception& ex) {
 		Log(LogCritical, "pki")
 			<< "Client TLS handshake failed. (" << ex.what() << ")";
 		return std::shared_ptr<X509>();
 	}
 
-	return stream->GetPeerCertificate();
+	Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
+
+	return sslConn.GetPeerCertificate();
 }
 
 int PkiUtility::WriteCert(const std::shared_ptr<X509>& cert, const String& trustedfile)
@@ -142,22 +148,10 @@ int PkiUtility::GenTicket(const String& cn, const String& salt, std::ostream& ti
 int PkiUtility::RequestCertificate(const String& host, const String& port, const String& keyfile,
 	const String& certfile, const String& cafile, const std::shared_ptr<X509>& trustedCert, const String& ticket)
 {
-	TcpSocket::Ptr client = new TcpSocket();
+	std::shared_ptr<boost::asio::ssl::context> sslContext;
 
 	try {
-		client->Connect(host, port);
-	} catch (const std::exception& ex) {
-		Log(LogCritical, "cli")
-			<< "Cannot connect to host '" << host << "' on port '" << port << "'";
-		Log(LogDebug, "cli")
-			<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
-		return 1;
-	}
-
-	std::shared_ptr<SSL_CTX> sslContext;
-
-	try {
-		sslContext = MakeSSLContext(certfile, keyfile);
+		sslContext = MakeAsioSslContext(certfile, keyfile);
 	} catch (const std::exception& ex) {
 		Log(LogCritical, "cli")
 			<< "Cannot make SSL context for cert path: '" << certfile << "' key path: '" << keyfile << "' ca path: '" << cafile << "'.";
@@ -166,17 +160,31 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
 		return 1;
 	}
 
-	TlsStream::Ptr stream = new TlsStream(client, host, RoleClient, sslContext);
+	auto stream (std::make_shared<AsioTlsStream>(IoEngine::Get().GetIoService(), *sslContext, host));
 
 	try {
-		stream->Handshake();
+		Connect(stream->lowest_layer(), host, port);
+	} catch (const std::exception& ex) {
+		Log(LogCritical, "cli")
+			<< "Cannot connect to host '" << host << "' on port '" << port << "'";
+		Log(LogDebug, "cli")
+			<< "Cannot connect to host '" << host << "' on port '" << port << "':\n" << DiagnosticInformation(ex);
+		return 1;
+	}
+
+	auto& sslConn (stream->next_layer());
+
+	try {
+		sslConn.handshake(sslConn.client);
 	} catch (const std::exception& ex) {
 		Log(LogCritical, "cli")
 			<< "Client TLS handshake failed: " << DiagnosticInformation(ex, false);
 		return 1;
 	}
 
-	std::shared_ptr<X509> peerCert = stream->GetPeerCertificate();
+	Defer shutdown ([&sslConn]() { sslConn.shutdown(); });
+
+	auto peerCert (sslConn.GetPeerCertificate());
 
 	if (X509_cmp(peerCert.get(), trustedCert.get())) {
 		Log(LogCritical, "cli", "Peer certificate does not match trusted certificate.");
@@ -196,36 +204,32 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
 		{ "params", params }
 	});
 
-	JsonRpc::SendMessage(stream, request);
-
-	String jsonString;
 	Dictionary::Ptr response;
-	StreamReadContext src;
 
-	for (;;) {
-		StreamReadStatus srs = JsonRpc::ReadMessage(stream, &jsonString, src);
+	try {
+		JsonRpc::SendMessage(stream, request);
+		stream->flush();
 
-		if (srs == StatusEof)
-			break;
+		for (;;) {
+			response = JsonRpc::DecodeMessage(JsonRpc::ReadMessage(stream));
 
-		if (srs != StatusNewItem)
-			continue;
-
-		response = JsonRpc::DecodeMessage(jsonString);
-
-		if (response && response->Contains("error")) {
-			Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
+			if (response && response->Contains("error")) {
+				Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log (notice or debug).");
 #ifdef I2_DEBUG
-			/* we shouldn't expose master errors to the user in production environments */
-			Log(LogCritical, "cli", response->Get("error"));
+				/* we shouldn't expose master errors to the user in production environments */
+				Log(LogCritical, "cli", response->Get("error"));
 #endif /* I2_DEBUG */
-			return 1;
+				return 1;
+			}
+
+			if (response && (response->Get("id") != msgid))
+				continue;
+
+			break;
 		}
-
-		if (response && (response->Get("id") != msgid))
-			continue;
-
-		break;
+	} catch (...) {
+		Log(LogCritical, "cli", "Could not fetch valid response. Please check the master log.");
+		return 1;
 	}
 
 	if (!response) {