AsioTlsStream: inherit from SharedObject

All usages of `AsioTlsStream` were already using `Shared<AsioTlsStream>` to
keep a reference-counted instance. This commit moves the reference counting to
`AsioTlsStream` itself by inheriting from `SharedObject`. This will allow to
implement methods making use of the fact that these objects are
reference-counted.

The changes outside of `lib/base/tlsstream.hpp` are merely replacing
`Shared<AsioTlsStream>::Ptr` with `AsioTlsStream::Ptr` everywhere.
This commit is contained in:
Julian Brost 2024-02-12 15:45:16 +01:00
parent 7a20d987f6
commit a85c188fed
21 changed files with 61 additions and 54 deletions

View File

@ -126,7 +126,7 @@ size_t NetString::WriteStringToStream(const Stream::Ptr& stream, const String& s
* @exception invalid_argument The input stream is invalid.
* @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
*/
String NetString::ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
String NetString::ReadStringFromStream(const AsioTlsStream::Ptr& stream,
ssize_t maxMessageLength)
{
namespace asio = boost::asio;
@ -205,7 +205,7 @@ String NetString::ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
* @exception invalid_argument The input stream is invalid.
* @see https://github.com/PeterScott/netstring-c/blob/master/netstring.c
*/
String NetString::ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
String NetString::ReadStringFromStream(const AsioTlsStream::Ptr& stream,
boost::asio::yield_context yc, ssize_t maxMessageLength)
{
namespace asio = boost::asio;
@ -284,7 +284,7 @@ String NetString::ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
*
* @return The amount of bytes written.
*/
size_t NetString::WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream, const String& str)
size_t NetString::WriteStringToStream(const AsioTlsStream::Ptr& stream, const String& str)
{
namespace asio = boost::asio;
@ -307,7 +307,7 @@ size_t NetString::WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream,
*
* @return The amount of bytes written.
*/
size_t NetString::WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream, const String& str, boost::asio::yield_context yc)
size_t NetString::WriteStringToStream(const AsioTlsStream::Ptr& stream, const String& str, boost::asio::yield_context yc)
{
namespace asio = boost::asio;

View File

@ -26,12 +26,12 @@ class NetString
public:
static StreamReadStatus ReadStringFromStream(const Stream::Ptr& stream, String *message, StreamReadContext& context,
bool may_wait = false, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const Shared<AsioTlsStream>::Ptr& stream,
static String ReadStringFromStream(const AsioTlsStream::Ptr& stream, ssize_t maxMessageLength = -1);
static String ReadStringFromStream(const AsioTlsStream::Ptr& stream,
boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static size_t WriteStringToStream(const Stream::Ptr& stream, const String& message);
static size_t WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream, const String& message);
static size_t WriteStringToStream(const Shared<AsioTlsStream>::Ptr& stream, const String& message, boost::asio::yield_context yc);
static size_t WriteStringToStream(const AsioTlsStream::Ptr& stream, const String& message);
static size_t WriteStringToStream(const AsioTlsStream::Ptr& stream, const String& message, boost::asio::yield_context yc);
static void WriteStringToStream(std::ostream& stream, const String& message);
private:

View File

@ -4,7 +4,7 @@
#define TLSSTREAM_H
#include "base/i2-base.hpp"
#include "base/shared.hpp"
#include "base/shared-object.hpp"
#include "base/socket.hpp"
#include "base/stream.hpp"
#include "base/tlsutility.hpp"
@ -102,15 +102,22 @@ private:
void BeforeHandshake(handshake_type type);
};
class AsioTlsStream : public boost::asio::buffered_stream<UnbufferedAsioTlsStream>
class AsioTlsStream : public SharedObject, public boost::asio::buffered_stream<UnbufferedAsioTlsStream>
{
public:
DECLARE_PTR_TYPEDEFS(AsioTlsStream);
inline
AsioTlsStream(boost::asio::io_context& ioContext, boost::asio::ssl::context& sslContext, const String& hostname = String())
: AsioTlsStream(UnbufferedAsioTlsStreamParams{ioContext, sslContext, hostname})
{
}
static AsioTlsStream::Ptr Make(boost::asio::io_context& ioContext, boost::asio::ssl::context& sslContext, const String& hostname = String())
{
return new AsioTlsStream(ioContext, sslContext, hostname);
}
private:
inline
AsioTlsStream(UnbufferedAsioTlsStreamParams init)
@ -120,7 +127,7 @@ private:
};
typedef boost::asio::buffered_stream<boost::asio::ip::tcp::socket> AsioTcpStream;
typedef std::pair<Shared<AsioTlsStream>::Ptr, Shared<AsioTcpStream>::Ptr> OptionalTlsStream;
typedef std::pair<AsioTlsStream::Ptr, Shared<AsioTcpStream>::Ptr> OptionalTlsStream;
}

View File

@ -42,7 +42,7 @@ namespace po = boost::program_options;
static ScriptFrame *l_ScriptFrame;
static Url::Ptr l_Url;
static Shared<AsioTlsStream>::Ptr l_TlsStream;
static AsioTlsStream::Ptr l_TlsStream;
static String l_Session;
REGISTER_CLICOMMAND("console", ConsoleCommand);
@ -522,7 +522,7 @@ incomplete:
*
* @returns AsioTlsStream pointer for future HTTP connections.
*/
Shared<AsioTlsStream>::Ptr ConsoleCommand::Connect()
AsioTlsStream::Ptr ConsoleCommand::Connect()
{
Shared<boost::asio::ssl::context>::Ptr sslContext;
@ -537,7 +537,7 @@ Shared<AsioTlsStream>::Ptr ConsoleCommand::Connect()
String host = l_Url->GetHost();
String port = l_Url->GetPort();
Shared<AsioTlsStream>::Ptr stream = Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, host);
AsioTlsStream::Ptr stream = AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, host);
try {
icinga::Connect(stream->lowest_layer(), host, port);

View File

@ -41,7 +41,7 @@ private:
mutable std::mutex m_Mutex;
mutable std::condition_variable m_CV;
static Shared<AsioTlsStream>::Ptr Connect();
static AsioTlsStream::Ptr Connect();
static Value ExecuteScript(const String& session, const String& command, bool sandboxed);
static Array::Ptr AutoCompleteScript(const String& session, const String& command, bool sandboxed);

View File

@ -315,7 +315,7 @@ void RedisConnection::Connect(asio::yield_context& yc)
Log(m_Parent ? LogNotice : LogInformation, "IcingaDB")
<< "Trying to connect to Redis server (async, TLS) on host '" << m_Host << ":" << m_Port << "'";
auto conn (Shared<AsioTlsStream>::Make(m_Strand.context(), *m_TLSContext, m_Host));
auto conn (AsioTlsStream::Make(m_Strand.context(), *m_TLSContext, m_Host));
auto& tlsConn (conn->next_layer());
auto connectTimeout (MakeTimeout(conn));
Defer cancelTimeout ([&connectTimeout]() { connectTimeout->Cancel(); });

View File

@ -244,7 +244,7 @@ namespace icinga
boost::asio::io_context::strand m_Strand;
Shared<TcpConn>::Ptr m_TcpConn;
Shared<UnixConn>::Ptr m_UnixConn;
Shared<AsioTlsStream>::Ptr m_TlsConn;
AsioTlsStream::Ptr m_TlsConn;
Atomic<bool> m_Connecting, m_Connected, m_Started;
struct {

View File

@ -449,7 +449,7 @@ void IfwApiCheckTask::ScriptFunc(const Checkable::Ptr& checkable, const CheckRes
return;
}
auto conn (Shared<AsioTlsStream>::Make(io, *ctx, expectedSan));
auto conn (AsioTlsStream::Make(io, *ctx, expectedSan));
IoEngine::SpawnCoroutine(
*strand,

View File

@ -612,7 +612,7 @@ OptionalTlsStream ElasticsearchWriter::Connect()
throw;
}
stream.first = Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, GetHost());
stream.first = AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, GetHost());
} else {
stream.second = Shared<AsioTcpStream>::Make(IoEngine::Get().GetIoContext());

View File

@ -184,7 +184,7 @@ void GelfWriter::ReconnectInternal()
throw;
}
m_Stream.first = Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, GetHost());
m_Stream.first = AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, GetHost());
} else {
m_Stream.second = Shared<AsioTcpStream>::Make(IoEngine::Get().GetIoContext());

View File

@ -158,7 +158,7 @@ OptionalTlsStream InfluxdbCommonWriter::Connect()
throw;
}
stream.first = Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, GetHost());
stream.first = AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, GetHost());
} else {
stream.second = Shared<AsioTcpStream>::Make(IoEngine::Get().GetIoContext());

View File

@ -526,7 +526,7 @@ void ApiListener::ListenerCoroutineProc(boost::asio::yield_context yc, const Sha
}
boost::shared_lock<decltype(m_SSLContextMutex)> lock (m_SSLContextMutex);
auto sslConn (Shared<AsioTlsStream>::Make(io, *m_SSLContext));
auto sslConn (AsioTlsStream::Make(io, *m_SSLContext));
lock.unlock();
sslConn->lowest_layer() = std::move(socket);
@ -581,7 +581,7 @@ void ApiListener::AddConnection(const Endpoint::Ptr& endpoint)
try {
boost::shared_lock<decltype(m_SSLContextMutex)> lock (m_SSLContextMutex);
auto sslConn (Shared<AsioTlsStream>::Make(io, *m_SSLContext, endpoint->GetName()));
auto sslConn (AsioTlsStream::Make(io, *m_SSLContext, endpoint->GetName()));
lock.unlock();
@ -615,7 +615,7 @@ void ApiListener::AddConnection(const Endpoint::Ptr& endpoint)
void ApiListener::NewClientHandler(
boost::asio::yield_context yc, const Shared<boost::asio::io_context::strand>::Ptr& strand,
const Shared<AsioTlsStream>::Ptr& client, const String& hostname, ConnectionRole role
const AsioTlsStream::Ptr& client, const String& hostname, ConnectionRole role
)
{
try {
@ -654,7 +654,7 @@ static const auto l_MyCapabilities (
*/
void ApiListener::NewClientHandlerInternal(
boost::asio::yield_context yc, const Shared<boost::asio::io_context::strand>::Ptr& strand,
const Shared<AsioTlsStream>::Ptr& client, const String& hostname, ConnectionRole role
const AsioTlsStream::Ptr& client, const String& hostname, ConnectionRole role
)
{
namespace asio = boost::asio;

View File

@ -191,11 +191,11 @@ private:
void NewClientHandler(
boost::asio::yield_context yc, const Shared<boost::asio::io_context::strand>::Ptr& strand,
const Shared<AsioTlsStream>::Ptr& client, const String& hostname, ConnectionRole role
const AsioTlsStream::Ptr& client, const String& hostname, ConnectionRole role
);
void NewClientHandlerInternal(
boost::asio::yield_context yc, const Shared<boost::asio::io_context::strand>::Ptr& strand,
const Shared<AsioTlsStream>::Ptr& client, const String& hostname, ConnectionRole role
const AsioTlsStream::Ptr& client, const String& hostname, ConnectionRole role
);
void ListenerCoroutineProc(boost::asio::yield_context yc, const Shared<boost::asio::ip::tcp::acceptor>::Ptr& server);

View File

@ -35,12 +35,12 @@ using namespace icinga;
auto const l_ServerHeader ("Icinga/" + Application::GetAppVersion());
HttpServerConnection::HttpServerConnection(const String& identity, bool authenticated, const Shared<AsioTlsStream>::Ptr& stream)
HttpServerConnection::HttpServerConnection(const String& identity, bool authenticated, const AsioTlsStream::Ptr& stream)
: HttpServerConnection(identity, authenticated, stream, IoEngine::Get().GetIoContext())
{
}
HttpServerConnection::HttpServerConnection(const String& identity, bool authenticated, const Shared<AsioTlsStream>::Ptr& stream, boost::asio::io_context& io)
HttpServerConnection::HttpServerConnection(const String& identity, bool authenticated, const AsioTlsStream::Ptr& stream, boost::asio::io_context& io)
: m_Stream(stream), m_Seen(Utility::GetTime()), m_IoStrand(io), m_ShuttingDown(false), m_HasStartedStreaming(false),
m_CheckLivenessTimer(io)
{

View File

@ -25,7 +25,7 @@ class HttpServerConnection final : public Object
public:
DECLARE_PTR_TYPEDEFS(HttpServerConnection);
HttpServerConnection(const String& identity, bool authenticated, const Shared<AsioTlsStream>::Ptr& stream);
HttpServerConnection(const String& identity, bool authenticated, const AsioTlsStream::Ptr& stream);
void Start();
void Disconnect();
@ -35,7 +35,7 @@ public:
private:
ApiUser::Ptr m_ApiUser;
Shared<AsioTlsStream>::Ptr m_Stream;
AsioTlsStream::Ptr m_Stream;
double m_Seen;
String m_PeerAddress;
boost::asio::io_context::strand m_IoStrand;
@ -43,7 +43,7 @@ private:
bool m_HasStartedStreaming;
boost::asio::deadline_timer m_CheckLivenessTimer;
HttpServerConnection(const String& identity, bool authenticated, const Shared<AsioTlsStream>::Ptr& stream, boost::asio::io_context& io);
HttpServerConnection(const String& identity, bool authenticated, const AsioTlsStream::Ptr& stream, boost::asio::io_context& io);
void ProcessMessages(boost::asio::yield_context yc);
void CheckLiveness(boost::asio::yield_context yc);

View File

@ -52,7 +52,7 @@ static bool GetDebugJsonRpcCached()
*
* @return The amount of bytes sent.
*/
size_t JsonRpc::SendMessage(const Shared<AsioTlsStream>::Ptr& stream, const Dictionary::Ptr& message)
size_t JsonRpc::SendMessage(const AsioTlsStream::Ptr& stream, const Dictionary::Ptr& message)
{
String json = JsonEncode(message);
@ -71,7 +71,7 @@ size_t JsonRpc::SendMessage(const Shared<AsioTlsStream>::Ptr& stream, const Dict
*
* @return The amount of bytes sent.
*/
size_t JsonRpc::SendMessage(const Shared<AsioTlsStream>::Ptr& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc)
size_t JsonRpc::SendMessage(const AsioTlsStream::Ptr& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc)
{
return JsonRpc::SendRawMessage(stream, JsonEncode(message), yc);
}
@ -85,7 +85,7 @@ size_t JsonRpc::SendMessage(const Shared<AsioTlsStream>::Ptr& stream, const Dict
*
* @return bytes sent
*/
size_t JsonRpc::SendRawMessage(const Shared<AsioTlsStream>::Ptr& stream, const String& json, boost::asio::yield_context yc)
size_t JsonRpc::SendRawMessage(const AsioTlsStream::Ptr& stream, const String& json, boost::asio::yield_context yc)
{
#ifdef I2_DEBUG
if (GetDebugJsonRpcCached())
@ -104,7 +104,7 @@ size_t JsonRpc::SendRawMessage(const Shared<AsioTlsStream>::Ptr& stream, const S
* @return A JSON string
*/
String JsonRpc::ReadMessage(const Shared<AsioTlsStream>::Ptr& stream, ssize_t maxMessageLength)
String JsonRpc::ReadMessage(const AsioTlsStream::Ptr& stream, ssize_t maxMessageLength)
{
String jsonString = NetString::ReadStringFromStream(stream, maxMessageLength);
@ -125,7 +125,7 @@ String JsonRpc::ReadMessage(const Shared<AsioTlsStream>::Ptr& stream, ssize_t ma
*
* @return A JSON string
*/
String JsonRpc::ReadMessage(const Shared<AsioTlsStream>::Ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength)
String JsonRpc::ReadMessage(const AsioTlsStream::Ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength)
{
String jsonString = NetString::ReadStringFromStream(stream, yc, maxMessageLength);

View File

@ -21,12 +21,12 @@ namespace icinga
class JsonRpc
{
public:
static size_t SendMessage(const Shared<AsioTlsStream>::Ptr& stream, const Dictionary::Ptr& message);
static size_t SendMessage(const Shared<AsioTlsStream>::Ptr& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc);
static size_t SendRawMessage(const Shared<AsioTlsStream>::Ptr& stream, const String& json, boost::asio::yield_context yc);
static size_t SendMessage(const AsioTlsStream::Ptr& stream, const Dictionary::Ptr& message);
static size_t SendMessage(const AsioTlsStream::Ptr& stream, const Dictionary::Ptr& message, boost::asio::yield_context yc);
static size_t SendRawMessage(const AsioTlsStream::Ptr& stream, const String& json, boost::asio::yield_context yc);
static String ReadMessage(const Shared<AsioTlsStream>::Ptr& stream, ssize_t maxMessageLength = -1);
static String ReadMessage(const Shared<AsioTlsStream>::Ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static String ReadMessage(const AsioTlsStream::Ptr& stream, ssize_t maxMessageLength = -1);
static String ReadMessage(const AsioTlsStream::Ptr& stream, boost::asio::yield_context yc, ssize_t maxMessageLength = -1);
static Dictionary::Ptr DecodeMessage(const String& message);

View File

@ -30,13 +30,13 @@ REGISTER_APIFUNCTION(SetLogPosition, log, &SetLogPositionHandler);
static RingBuffer l_TaskStats (15 * 60);
JsonRpcConnection::JsonRpcConnection(const String& identity, bool authenticated,
const Shared<AsioTlsStream>::Ptr& stream, ConnectionRole role)
const AsioTlsStream::Ptr& stream, ConnectionRole role)
: JsonRpcConnection(identity, authenticated, stream, role, IoEngine::Get().GetIoContext())
{
}
JsonRpcConnection::JsonRpcConnection(const String& identity, bool authenticated,
const Shared<AsioTlsStream>::Ptr& stream, ConnectionRole role, boost::asio::io_context& io)
const AsioTlsStream::Ptr& stream, ConnectionRole role, boost::asio::io_context& io)
: m_Identity(identity), m_Authenticated(authenticated), m_Stream(stream), m_Role(role),
m_Timestamp(Utility::GetTime()), m_Seen(Utility::GetTime()), m_IoStrand(io),
m_OutgoingMessagesQueued(io), m_WriterDone(io), m_ShuttingDown(false),
@ -151,7 +151,7 @@ Endpoint::Ptr JsonRpcConnection::GetEndpoint() const
return m_Endpoint;
}
Shared<AsioTlsStream>::Ptr JsonRpcConnection::GetStream() const
AsioTlsStream::Ptr JsonRpcConnection::GetStream() const
{
return m_Stream;
}

View File

@ -43,7 +43,7 @@ class JsonRpcConnection final : public Object
public:
DECLARE_PTR_TYPEDEFS(JsonRpcConnection);
JsonRpcConnection(const String& identity, bool authenticated, const Shared<AsioTlsStream>::Ptr& stream, ConnectionRole role);
JsonRpcConnection(const String& identity, bool authenticated, const AsioTlsStream::Ptr& stream, ConnectionRole role);
void Start();
@ -51,7 +51,7 @@ public:
String GetIdentity() const;
bool IsAuthenticated() const;
Endpoint::Ptr GetEndpoint() const;
Shared<AsioTlsStream>::Ptr GetStream() const;
AsioTlsStream::Ptr GetStream() const;
ConnectionRole GetRole() const;
void Disconnect();
@ -69,7 +69,7 @@ private:
String m_Identity;
bool m_Authenticated;
Endpoint::Ptr m_Endpoint;
Shared<AsioTlsStream>::Ptr m_Stream;
AsioTlsStream::Ptr m_Stream;
ConnectionRole m_Role;
double m_Timestamp;
double m_Seen;
@ -80,7 +80,7 @@ private:
Atomic<bool> m_ShuttingDown;
boost::asio::deadline_timer m_CheckLivenessTimer, m_HeartbeatTimer;
JsonRpcConnection(const String& identity, bool authenticated, const Shared<AsioTlsStream>::Ptr& stream, ConnectionRole role, boost::asio::io_context& io);
JsonRpcConnection(const String& identity, bool authenticated, const AsioTlsStream::Ptr& stream, ConnectionRole role, boost::asio::io_context& io);
void HandleIncomingMessages(boost::asio::yield_context yc);
void WriteOutgoingMessages(boost::asio::yield_context yc);

View File

@ -95,7 +95,7 @@ std::shared_ptr<X509> PkiUtility::FetchCert(const String& host, const String& po
return std::shared_ptr<X509>();
}
auto stream (Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, host));
auto stream (AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, host));
try {
Connect(stream->lowest_layer(), host, port);
@ -163,7 +163,7 @@ int PkiUtility::RequestCertificate(const String& host, const String& port, const
return 1;
}
auto stream (Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, host));
auto stream (AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, host));
try {
Connect(stream->lowest_layer(), host, port);

View File

@ -174,7 +174,7 @@ static int FormatOutput(const Dictionary::Ptr& result)
*
* @returns AsioTlsStream pointer for future HTTP connections.
*/
static Shared<AsioTlsStream>::Ptr Connect(const String& host, const String& port)
static AsioTlsStream::Ptr Connect(const String& host, const String& port)
{
Shared<boost::asio::ssl::context>::Ptr sslContext;
@ -186,7 +186,7 @@ static Shared<AsioTlsStream>::Ptr Connect(const String& host, const String& port
throw;
}
Shared<AsioTlsStream>::Ptr stream = Shared<AsioTlsStream>::Make(IoEngine::Get().GetIoContext(), *sslContext, host);
AsioTlsStream::Ptr stream = AsioTlsStream::Make(IoEngine::Get().GetIoContext(), *sslContext, host);
try {
icinga::Connect(stream->lowest_layer(), host, port);
@ -338,7 +338,7 @@ static Dictionary::Ptr FetchData(const String& host, const String& port, const S
namespace beast = boost::beast;
namespace http = beast::http;
Shared<AsioTlsStream>::Ptr tlsStream;
AsioTlsStream::Ptr tlsStream;
try {
tlsStream = Connect(host, port);