From ee5954726dc1c054ac952b3a7e45e2ac972cfb0b Mon Sep 17 00:00:00 2001 From: Jean Flach Date: Thu, 8 Feb 2018 14:54:52 +0100 Subject: [PATCH] Authenticate API user before parsing body --- lib/remote/apiuser.cpp | 28 ++++ lib/remote/apiuser.hpp | 1 + lib/remote/httprequest.cpp | 72 +++++----- lib/remote/httprequest.hpp | 1 + lib/remote/httpserverconnection.cpp | 210 +++++++++++++--------------- lib/remote/httpserverconnection.hpp | 6 +- 6 files changed, 169 insertions(+), 149 deletions(-) diff --git a/lib/remote/apiuser.cpp b/lib/remote/apiuser.cpp index 183291af3..416eedb34 100644 --- a/lib/remote/apiuser.cpp +++ b/lib/remote/apiuser.cpp @@ -20,6 +20,7 @@ #include "remote/apiuser.hpp" #include "remote/apiuser-ti.cpp" #include "base/configtype.hpp" +#include "base/base64.hpp" using namespace icinga; @@ -34,3 +35,30 @@ ApiUser::Ptr ApiUser::GetByClientCN(const String& cn) return nullptr; } + +ApiUser::Ptr ApiUser::GetByAuthHeader(const String& auth_header) { + String::SizeType pos = auth_header.FindFirstOf(" "); + String username, password; + + if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") { + String credentials_base64 = auth_header.SubStr(pos + 1); + String credentials = Base64::Decode(credentials_base64); + + String::SizeType cpos = credentials.FindFirstOf(":"); + + if (cpos != String::NPos) { + username = credentials.SubStr(0, cpos); + password = credentials.SubStr(cpos + 1); + } + } + + const ApiUser::Ptr& user = ApiUser::GetByName(username); + + /* Deny authentication if 1) given password is empty 2) configured password does not match. */ + if (password.IsEmpty()) + return nullptr; + else if (user && user->GetPassword() != password) + return nullptr; + + return user; +} diff --git a/lib/remote/apiuser.hpp b/lib/remote/apiuser.hpp index 755273bf4..15b1c41e3 100644 --- a/lib/remote/apiuser.hpp +++ b/lib/remote/apiuser.hpp @@ -36,6 +36,7 @@ public: DECLARE_OBJECTNAME(ApiUser); static ApiUser::Ptr GetByClientCN(const String& cn); + static ApiUser::Ptr GetByAuthHeader(const String& auth_header); }; } diff --git a/lib/remote/httprequest.cpp b/lib/remote/httprequest.cpp index 0d6c4a618..c5382d254 100644 --- a/lib/remote/httprequest.cpp +++ b/lib/remote/httprequest.cpp @@ -26,6 +26,7 @@ using namespace icinga; HttpRequest::HttpRequest(Stream::Ptr stream) : CompleteHeaders(false), + CompleteHeaderCheck(false), CompleteBody(false), ProtocolVersion(HttpVersion11), Headers(new Dictionary()), @@ -39,7 +40,7 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait) return false; if (m_State != HttpRequestStart && m_State != HttpRequestHeaders) - return false; + BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state")); String line; StreamReadStatus srs = m_Stream->ReadLine(&line, src, may_wait); @@ -105,19 +106,19 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait) bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait) { - if (!m_Stream || m_State != HttpRequestBody) + if (!m_Stream) return false; + if (m_State != HttpRequestBody) + BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state")); + /* we're done if the request doesn't contain a message body */ if (!Headers->Contains("content-length") && !Headers->Contains("transfer-encoding")) { CompleteBody = true; - return true; + return false; } else if (!m_Body) m_Body = new FIFO(); - if (CompleteBody) - return true; - if (Headers->Get("transfer-encoding") == "chunked") { if (!m_ChunkContext) m_ChunkContext = std::make_shared(std::ref(src)); @@ -135,39 +136,38 @@ bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait) if (size == 0) { CompleteBody = true; - return true; - } - } else { - if (src.Eof) - BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); - - if (src.MustRead) { - if (!src.FillFromStream(m_Stream, false)) { - src.Eof = true; - BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); - } - - src.MustRead = false; - } - - long length_indicator_signed = Convert::ToLong(Headers->Get("content-length")); - - if (length_indicator_signed < 0) - BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative.")); - - size_t length_indicator = length_indicator_signed; - - if (src.Size < length_indicator) { - src.MustRead = true; return false; - } - - m_Body->Write(src.Buffer, length_indicator); - src.DropData(length_indicator); - CompleteBody = true; - return true; + } else + return true; } + if (src.Eof) + BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); + + if (src.MustRead) { + if (!src.FillFromStream(m_Stream, false)) { + src.Eof = true; + BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); + } + + src.MustRead = false; + } + + long length_indicator_signed = Convert::ToLong(Headers->Get("content-length")); + + if (length_indicator_signed < 0) + BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative.")); + + size_t length_indicator = length_indicator_signed; + + if (src.Size < length_indicator) { + src.MustRead = true; + return false; + } + + m_Body->Write(src.Buffer, length_indicator); + src.DropData(length_indicator); + CompleteBody = true; return true; } diff --git a/lib/remote/httprequest.hpp b/lib/remote/httprequest.hpp index e8591474c..b456d7229 100644 --- a/lib/remote/httprequest.hpp +++ b/lib/remote/httprequest.hpp @@ -53,6 +53,7 @@ struct HttpRequest { public: bool CompleteHeaders; + bool CompleteHeaderCheck; bool CompleteBody; String RequestMethod; diff --git a/lib/remote/httpserverconnection.cpp b/lib/remote/httpserverconnection.cpp index 9e400da16..c22f0ee15 100644 --- a/lib/remote/httpserverconnection.cpp +++ b/lib/remote/httpserverconnection.cpp @@ -90,100 +90,105 @@ void HttpServerConnection::Disconnect() bool HttpServerConnection::ProcessMessage() { bool res; + HttpResponse response(m_Stream, m_CurrentRequest); - try { - res = m_CurrentRequest.ParseHeader(m_Context, false); - } catch (const std::invalid_argument& ex) { - HttpResponse response(m_Stream, m_CurrentRequest); - response.SetStatus(400, "Bad request"); - String msg = String("

Bad request

") + ex.what() + "

"; - response.WriteBody(msg.CStr(), msg.GetLength()); - response.Finish(); + if (!m_CurrentRequest.CompleteHeaders) { + try { + res = m_CurrentRequest.ParseHeader(m_Context, false); + } catch (const std::invalid_argument& ex) { + response.SetStatus(400, "Bad Request"); + String msg = String("

Bad Request

") + ex.what() + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); - m_Stream->Shutdown(); - return false; - } catch (const std::exception& ex) { - HttpResponse response(m_Stream, m_CurrentRequest); - response.SetStatus(400, "Bad request"); - String msg = "

Bad request

" + DiagnosticInformation(ex) + "

"; - response.WriteBody(msg.CStr(), msg.GetLength()); - response.Finish(); + m_Stream->Shutdown(); + return false; + } catch (const std::exception& ex) { + response.SetStatus(500, "Internal Server Error"); + String msg = "

Internal Server Error

" + DiagnosticInformation(ex) + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); - m_Stream->Shutdown(); - return false; + m_Stream->Shutdown(); + return false; + } + return res; } - if (m_CurrentRequest.CompleteHeaders) { - m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync, - HttpServerConnection::Ptr(this), m_CurrentRequest)); - - m_Seen = Utility::GetTime(); - m_PendingRequests++; - - m_CurrentRequest.~HttpRequest(); - new (&m_CurrentRequest) HttpRequest(m_Stream); - - return true; - } - - return res; -} - -void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) -{ - String auth_header = request.Headers->Get("authorization"); - - String::SizeType pos = auth_header.FindFirstOf(" "); - String username, password; - - if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") { - String credentials_base64 = auth_header.SubStr(pos + 1); - String credentials = Base64::Decode(credentials_base64); - - String::SizeType cpos = credentials.FindFirstOf(":"); - - if (cpos != String::NPos) { - username = credentials.SubStr(0, cpos); - password = credentials.SubStr(cpos + 1); + if (!m_CurrentRequest.CompleteHeaderCheck) { + m_CurrentRequest.CompleteHeaderCheck = true; + if (!ManageHeaders(response)) { + m_Stream->Shutdown(); + return false; } } - ApiUser::Ptr user; + if (!m_CurrentRequest.CompleteBody) { + try { + res = m_CurrentRequest.ParseBody(m_Context, false); + } catch (const std::invalid_argument& ex) { + response.SetStatus(400, "Bad Request"); + String msg = String("

Bad Request

") + ex.what() + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); + + m_Stream->Shutdown(); + return false; + } catch (const std::exception& ex) { + response.SetStatus(500, "Internal Server Error"); + String msg = "

Internal Server Error

" + DiagnosticInformation(ex) + "

"; + response.WriteBody(msg.CStr(), msg.GetLength()); + response.Finish(); + + m_Stream->Shutdown(); + return false; + } + return res; + } + + m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync, + HttpServerConnection::Ptr(this), m_CurrentRequest, response, m_AuthenticatedUser)); + + m_Seen = Utility::GetTime(); + m_PendingRequests++; + + m_CurrentRequest.~HttpRequest(); + new (&m_CurrentRequest) HttpRequest(m_Stream); + + return false; +} + +bool HttpServerConnection::ManageHeaders(HttpResponse& response) +{ + if (m_CurrentRequest.Headers->Get("expect") == "100-continue") { + String continueResponse = "HTTP/1.1 100 Continue\r\n\r\n"; + m_Stream->Write(continueResponse.CStr(), continueResponse.GetLength()); + } /* client_cn matched. */ if (m_ApiUser) - user = m_ApiUser; - else { - user = ApiUser::GetByName(username); + m_AuthenticatedUser = m_ApiUser; + else + m_AuthenticatedUser = ApiUser::GetByAuthHeader(m_CurrentRequest.Headers->Get("authorization")); - /* Deny authentication if 1) given password is empty 2) configured password does not match. */ - if (password.IsEmpty()) - user.reset(); - else if (user && user->GetPassword() != password) - user.reset(); - } - - String requestUrl = request.RequestUrl->Format(); + String requestUrl = m_CurrentRequest.RequestUrl->Format(); Socket::Ptr socket = m_Stream->GetSocket(); Log(LogInformation, "HttpServerConnection") - << "Request: " << request.RequestMethod << " " << requestUrl + << "Request: " << m_CurrentRequest.RequestMethod << " " << requestUrl << " (from " << (socket ? socket->GetPeerAddress() : "") - << ", user: " << (user ? user->GetName() : "") << ")"; - - HttpResponse response(m_Stream, request); + << ", user: " << (m_AuthenticatedUser ? m_AuthenticatedUser->GetName() : "") << ")"; ApiListener::Ptr listener = ApiListener::GetInstance(); if (!listener) - return; + return false; Array::Ptr headerAllowOrigin = listener->GetAccessControlAllowOrigin(); if (headerAllowOrigin->GetLength() != 0) { - String origin = request.Headers->Get("origin"); - + String origin = m_CurrentRequest.Headers->Get("origin"); { ObjectLock olock(headerAllowOrigin); @@ -196,9 +201,9 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) if (listener->GetAccessControlAllowCredentials()) response.AddHeader("Access-Control-Allow-Credentials", "true"); - String accessControlRequestMethodHeader = request.Headers->Get("access-control-request-method"); + String accessControlRequestMethodHeader = m_CurrentRequest.Headers->Get("access-control-request-method"); - if (!accessControlRequestMethodHeader.IsEmpty()) { + if (m_CurrentRequest.RequestMethod == "OPTIONS" && !accessControlRequestMethodHeader.IsEmpty()) { response.SetStatus(200, "OK"); response.AddHeader("Access-Control-Allow-Methods", listener->GetAccessControlAllowMethods()); @@ -208,27 +213,27 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) response.WriteBody(msg.CStr(), msg.GetLength()); response.Finish(); - m_PendingRequests--; - - return; + return false; } } - String accept_header = request.Headers->Get("accept"); - - if (request.RequestMethod != "GET" && accept_header != "application/json") { + if (m_CurrentRequest.RequestMethod != "GET" && m_CurrentRequest.Headers->Get("accept") != "application/json") { response.SetStatus(400, "Wrong Accept header"); response.AddHeader("Content-Type", "text/html"); String msg = "

Accept header is missing or not set to 'application/json'.

"; response.WriteBody(msg.CStr(), msg.GetLength()); - } else if (!user) { + response.Finish(); + return false; + } + + if (!m_AuthenticatedUser) { Log(LogWarning, "HttpServerConnection") - << "Unauthorized request: " << request.RequestMethod << " " << requestUrl; + << "Unauthorized request: " << m_CurrentRequest.RequestMethod << " " << requestUrl; response.SetStatus(401, "Unauthorized"); response.AddHeader("WWW-Authenticate", "Basic realm=\"Icinga 2\""); - if (request.Headers->Get("accept") == "application/json") { + if (m_CurrentRequest.Headers->Get("accept") == "application/json") { Dictionary::Ptr result = new Dictionary({ { "error", 401 }, { "status", "Unauthorized. Please check your user credentials." } @@ -240,44 +245,25 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) String msg = "

Unauthorized. Please check your user credentials.

"; response.WriteBody(msg.CStr(), msg.GetLength()); } - } else { - bool res = true; - while (!request.CompleteBody) - res = request.ParseBody(m_Context, false); - if (!res) { - Log(LogCritical, "HttpServerConnection", "Failed to read body"); - Dictionary::Ptr result = new Dictionary({ - { "error", 400 }, - { "status", "Bad Request: Malformed body." } - }); - HttpUtility::SendJsonBody(response, nullptr, result); - } else { - try { - HttpHandler::ProcessRequest(user, request, response); - } catch (const std::exception& ex) { - Log(LogCritical, "HttpServerConnection") - << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex); - response.SetStatus(503, "Unhandled exception"); - String errorInfo = DiagnosticInformation(ex); + response.Finish(); + return false; + } - if (request.Headers->Get("accept") == "application/json") { - Dictionary::Ptr result = new Dictionary({ - { "error", 503 }, - { "status", errorInfo } - }); + return true; +} - HttpUtility::SendJsonBody(response, nullptr, result); - } else { - response.AddHeader("Content-Type", "text/plain"); - response.WriteBody(errorInfo.CStr(), errorInfo.GetLength()); - } - } - } +void HttpServerConnection::ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr& user) +{ + try { + HttpHandler::ProcessRequest(user, request, response); + } catch (const std::exception& ex) { + Log(LogCritical, "HttpServerConnection") + << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex); + HttpUtility::SendJsonError(response, nullptr, 503, "Unhandled exception" , DiagnosticInformation(ex)); } response.Finish(); - m_PendingRequests--; } diff --git a/lib/remote/httpserverconnection.hpp b/lib/remote/httpserverconnection.hpp index f52110013..104df7509 100644 --- a/lib/remote/httpserverconnection.hpp +++ b/lib/remote/httpserverconnection.hpp @@ -21,6 +21,7 @@ #define HTTPSERVERCONNECTION_H #include "remote/httprequest.hpp" +#include "remote/httpresponse.hpp" #include "remote/apiuser.hpp" #include "base/tlsstream.hpp" #include "base/timer.hpp" @@ -51,6 +52,7 @@ public: private: ApiUser::Ptr m_ApiUser; + ApiUser::Ptr m_AuthenticatedUser; TlsStream::Ptr m_Stream; double m_Seen; HttpRequest m_CurrentRequest; @@ -67,7 +69,9 @@ private: static void TimeoutTimerHandler(); void CheckLiveness(); - void ProcessMessageAsync(HttpRequest& request); + bool ManageHeaders(HttpResponse& response); + + void ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr&); }; }