Authenticate API user before parsing body

This commit is contained in:
Jean Flach 2018-02-08 14:54:52 +01:00
parent a46dc64e6a
commit aa1ccd7ada
6 changed files with 169 additions and 147 deletions

View File

@ -20,6 +20,7 @@
#include "remote/apiuser.hpp"
#include "remote/apiuser.tcpp"
#include "base/configtype.hpp"
#include "base/base64.hpp"
using namespace icinga;
@ -34,3 +35,30 @@ ApiUser::Ptr ApiUser::GetByClientCN(const String& cn)
return ApiUser::Ptr();
}
ApiUser::Ptr ApiUser::GetByAuthHeader(const String& auth_header) {
String::SizeType pos = auth_header.FindFirstOf(" ");
String username, password;
if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") {
String credentials_base64 = auth_header.SubStr(pos + 1);
String credentials = Base64::Decode(credentials_base64);
String::SizeType cpos = credentials.FindFirstOf(":");
if (cpos != String::NPos) {
username = credentials.SubStr(0, cpos);
password = credentials.SubStr(cpos + 1);
}
}
const ApiUser::Ptr& user = ApiUser::GetByName(username);
/* Deny authentication if 1) given password is empty 2) configured password does not match. */
if (password.IsEmpty())
return nullptr;
else if (user && user->GetPassword() != password)
return nullptr;
return user;
}

View File

@ -36,6 +36,7 @@ public:
DECLARE_OBJECTNAME(ApiUser);
static ApiUser::Ptr GetByClientCN(const String& cn);
static ApiUser::Ptr GetByAuthHeader(const String& auth_header);
};
}

View File

@ -30,6 +30,7 @@ using namespace icinga;
HttpRequest::HttpRequest(const Stream::Ptr& stream)
: CompleteHeaders(false),
CompleteHeaderCheck(false),
CompleteBody(false),
ProtocolVersion(HttpVersion11),
Headers(new Dictionary()),
@ -43,7 +44,7 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait)
return false;
if (m_State != HttpRequestStart && m_State != HttpRequestHeaders)
return false;
BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state"));
String line;
StreamReadStatus srs = m_Stream->ReadLine(&line, src, may_wait);
@ -109,19 +110,19 @@ bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait)
bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait)
{
if (!m_Stream || m_State != HttpRequestBody)
if (!m_Stream)
return false;
if (m_State != HttpRequestBody)
BOOST_THROW_EXCEPTION(std::runtime_error("Invalid HTTP state"));
/* we're done if the request doesn't contain a message body */
if (!Headers->Contains("content-length") && !Headers->Contains("transfer-encoding")) {
CompleteBody = true;
return true;
return false;
} else if (!m_Body)
m_Body = new FIFO();
if (CompleteBody)
return true;
if (Headers->Get("transfer-encoding") == "chunked") {
if (!m_ChunkContext)
m_ChunkContext = boost::make_shared<ChunkReadContext>(boost::ref(src));
@ -139,39 +140,38 @@ bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait)
if (size == 0) {
CompleteBody = true;
return true;
}
} else {
if (src.Eof)
BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
if (src.MustRead) {
if (!src.FillFromStream(m_Stream, false)) {
src.Eof = true;
BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
}
src.MustRead = false;
}
long length_indicator_signed = Convert::ToLong(Headers->Get("content-length"));
if (length_indicator_signed < 0)
BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative."));
size_t length_indicator = length_indicator_signed;
if (src.Size < length_indicator) {
src.MustRead = true;
return false;
}
m_Body->Write(src.Buffer, length_indicator);
src.DropData(length_indicator);
CompleteBody = true;
return true;
} else
return true;
}
if (src.Eof)
BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
if (src.MustRead) {
if (!src.FillFromStream(m_Stream, false)) {
src.Eof = true;
BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body"));
}
src.MustRead = false;
}
long length_indicator_signed = Convert::ToLong(Headers->Get("content-length"));
if (length_indicator_signed < 0)
BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative."));
size_t length_indicator = length_indicator_signed;
if (src.Size < length_indicator) {
src.MustRead = true;
return false;
}
m_Body->Write(src.Buffer, length_indicator);
src.DropData(length_indicator);
CompleteBody = true;
return true;
}

View File

@ -53,6 +53,7 @@ struct I2_REMOTE_API HttpRequest
{
public:
bool CompleteHeaders;
bool CompleteHeaderCheck;
bool CompleteBody;
String RequestMethod;

View File

@ -90,100 +90,105 @@ void HttpServerConnection::Disconnect(void)
bool HttpServerConnection::ProcessMessage(void)
{
bool res;
HttpResponse response(m_Stream, m_CurrentRequest);
try {
res = m_CurrentRequest.ParseHeader(m_Context, false);
} catch (const std::invalid_argument& ex) {
HttpResponse response(m_Stream, m_CurrentRequest);
response.SetStatus(400, "Bad request");
String msg = String("<h1>Bad request</h1><p><pre>") + ex.what() + "</pre></p>";
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
if (!m_CurrentRequest.CompleteHeaders) {
try {
res = m_CurrentRequest.ParseHeader(m_Context, false);
} catch (const std::invalid_argument& ex) {
response.SetStatus(400, "Bad Request");
String msg = String("<h1>Bad Request</h1><p><pre>") + ex.what() + "</pre></p>";
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
m_Stream->Shutdown();
return false;
} catch (const std::exception& ex) {
HttpResponse response(m_Stream, m_CurrentRequest);
response.SetStatus(400, "Bad request");
String msg = "<h1>Bad request</h1><p><pre>" + DiagnosticInformation(ex) + "</pre></p>";
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
m_Stream->Shutdown();
return false;
} catch (const std::exception& ex) {
response.SetStatus(500, "Internal Server Error");
String msg = "<h1>Internal Server Error</h1><p><pre>" + DiagnosticInformation(ex) + "</pre></p>";
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
m_Stream->Shutdown();
return false;
m_Stream->Shutdown();
return false;
}
return res;
}
if (m_CurrentRequest.CompleteHeaders) {
m_RequestQueue.Enqueue(boost::bind(&HttpServerConnection::ProcessMessageAsync,
HttpServerConnection::Ptr(this), m_CurrentRequest));
m_Seen = Utility::GetTime();
m_PendingRequests++;
m_CurrentRequest.~HttpRequest();
new (&m_CurrentRequest) HttpRequest(m_Stream);
return true;
}
return res;
}
void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
{
String auth_header = request.Headers->Get("authorization");
String::SizeType pos = auth_header.FindFirstOf(" ");
String username, password;
if (pos != String::NPos && auth_header.SubStr(0, pos) == "Basic") {
String credentials_base64 = auth_header.SubStr(pos + 1);
String credentials = Base64::Decode(credentials_base64);
String::SizeType cpos = credentials.FindFirstOf(":");
if (cpos != String::NPos) {
username = credentials.SubStr(0, cpos);
password = credentials.SubStr(cpos + 1);
if (!m_CurrentRequest.CompleteHeaderCheck) {
m_CurrentRequest.CompleteHeaderCheck = true;
if (!ManageHeaders(response)) {
m_Stream->Shutdown();
return false;
}
}
ApiUser::Ptr user;
if (!m_CurrentRequest.CompleteBody) {
try {
res = m_CurrentRequest.ParseBody(m_Context, false);
} catch (const std::invalid_argument& ex) {
response.SetStatus(400, "Bad Request");
String msg = String("<h1>Bad Request</h1><p><pre>") + ex.what() + "</pre></p>";
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
m_Stream->Shutdown();
return false;
} catch (const std::exception& ex) {
response.SetStatus(500, "Internal Server Error");
String msg = "<h1>Internal Server Error</h1><p><pre>" + DiagnosticInformation(ex) + "</pre></p>";
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
m_Stream->Shutdown();
return false;
}
return res;
}
m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync,
HttpServerConnection::Ptr(this), m_CurrentRequest, response, m_AuthenticatedUser));
m_Seen = Utility::GetTime();
m_PendingRequests++;
m_CurrentRequest.~HttpRequest();
new (&m_CurrentRequest) HttpRequest(m_Stream);
return false;
}
bool HttpServerConnection::ManageHeaders(HttpResponse& response)
{
if (m_CurrentRequest.Headers->Get("expect") == "100-continue") {
String continueResponse = "HTTP/1.1 100 Continue\r\n\r\n";
m_Stream->Write(continueResponse.CStr(), continueResponse.GetLength());
}
/* client_cn matched. */
if (m_ApiUser)
user = m_ApiUser;
else {
user = ApiUser::GetByName(username);
m_AuthenticatedUser = m_ApiUser;
else
m_AuthenticatedUser = ApiUser::GetByAuthHeader(m_CurrentRequest.Headers->Get("authorization"));
/* Deny authentication if 1) given password is empty 2) configured password does not match. */
if (password.IsEmpty())
user.reset();
else if (user && user->GetPassword() != password)
user.reset();
}
String requestUrl = request.RequestUrl->Format();
String requestUrl = m_CurrentRequest.RequestUrl->Format();
Socket::Ptr socket = m_Stream->GetSocket();
Log(LogInformation, "HttpServerConnection")
<< "Request: " << request.RequestMethod << " " << requestUrl
<< "Request: " << m_CurrentRequest.RequestMethod << " " << requestUrl
<< " (from " << (socket ? socket->GetPeerAddress() : "<unkown>")
<< ", user: " << (user ? user->GetName() : "<unauthenticated>") << ")";
HttpResponse response(m_Stream, request);
<< ", user: " << (m_AuthenticatedUser ? m_AuthenticatedUser->GetName() : "<unauthenticated>") << ")";
ApiListener::Ptr listener = ApiListener::GetInstance();
if (!listener)
return;
return false;
Array::Ptr headerAllowOrigin = listener->GetAccessControlAllowOrigin();
if (headerAllowOrigin->GetLength() != 0) {
String origin = request.Headers->Get("origin");
String origin = m_CurrentRequest.Headers->Get("origin");
{
ObjectLock olock(headerAllowOrigin);
@ -196,9 +201,9 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
if (listener->GetAccessControlAllowCredentials())
response.AddHeader("Access-Control-Allow-Credentials", "true");
String accessControlRequestMethodHeader = request.Headers->Get("access-control-request-method");
String accessControlRequestMethodHeader = m_CurrentRequest.Headers->Get("access-control-request-method");
if (!accessControlRequestMethodHeader.IsEmpty()) {
if (m_CurrentRequest.RequestMethod == "OPTIONS" && !accessControlRequestMethodHeader.IsEmpty()) {
response.SetStatus(200, "OK");
response.AddHeader("Access-Control-Allow-Methods", listener->GetAccessControlAllowMethods());
@ -208,27 +213,27 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
response.WriteBody(msg.CStr(), msg.GetLength());
response.Finish();
m_PendingRequests--;
return;
return false;
}
}
String accept_header = request.Headers->Get("accept");
if (request.RequestMethod != "GET" && accept_header != "application/json") {
if (m_CurrentRequest.RequestMethod != "GET" && m_CurrentRequest.Headers->Get("accept") != "application/json") {
response.SetStatus(400, "Wrong Accept header");
response.AddHeader("Content-Type", "text/html");
String msg = "<h1>Accept header is missing or not set to 'application/json'.</h1>";
response.WriteBody(msg.CStr(), msg.GetLength());
} else if (!user) {
response.Finish();
return false;
}
if (!m_AuthenticatedUser) {
Log(LogWarning, "HttpServerConnection")
<< "Unauthorized request: " << request.RequestMethod << " " << requestUrl;
<< "Unauthorized request: " << m_CurrentRequest.RequestMethod << " " << requestUrl;
response.SetStatus(401, "Unauthorized");
response.AddHeader("WWW-Authenticate", "Basic realm=\"Icinga 2\"");
if (request.Headers->Get("accept") == "application/json") {
if (m_CurrentRequest.Headers->Get("accept") == "application/json") {
Dictionary::Ptr result = new Dictionary();
result->Set("error", 401);
@ -240,42 +245,25 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request)
String msg = "<h1>Unauthorized. Please check your user credentials.</h1>";
response.WriteBody(msg.CStr(), msg.GetLength());
}
} else {
bool res = true;
while (!request.CompleteBody)
res = request.ParseBody(m_Context, false);
if (!res) {
Log(LogCritical, "HttpServerConnection", "Failed to read body");
Dictionary::Ptr result = new Dictionary;
result->Set("error", 404);
result->Set("status", "Bad Request: Malformed body.");
HttpUtility::SendJsonBody(response, result);
} else {
try {
HttpHandler::ProcessRequest(user, request, response);
} catch (const std::exception& ex) {
Log(LogCritical, "HttpServerConnection")
<< "Unhandled exception while processing Http request: " << DiagnosticInformation(ex);
response.SetStatus(503, "Unhandled exception");
String errorInfo = DiagnosticInformation(ex);
response.Finish();
return false;
}
if (request.Headers->Get("accept") == "application/json") {
Dictionary::Ptr result = new Dictionary();
result->Set("error", 503);
result->Set("status", errorInfo);
return true;
}
HttpUtility::SendJsonBody(response, result);
} else {
response.AddHeader("Content-Type", "text/plain");
response.WriteBody(errorInfo.CStr(), errorInfo.GetLength());
}
}
}
void HttpServerConnection::ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr& user)
{
try {
HttpHandler::ProcessRequest(user, request, response);
} catch (const std::exception& ex) {
Log(LogCritical, "HttpServerConnection")
<< "Unhandled exception while processing Http request: " << DiagnosticInformation(ex);
HttpUtility::SendJsonError(response, 503, "Unhandled exception" , DiagnosticInformation(ex));
}
response.Finish();
m_PendingRequests--;
}

View File

@ -21,6 +21,7 @@
#define HTTPSERVERCONNECTION_H
#include "remote/httprequest.hpp"
#include "remote/httpresponse.hpp"
#include "remote/apiuser.hpp"
#include "base/tlsstream.hpp"
#include "base/timer.hpp"
@ -51,6 +52,7 @@ public:
private:
ApiUser::Ptr m_ApiUser;
ApiUser::Ptr m_AuthenticatedUser;
TlsStream::Ptr m_Stream;
double m_Seen;
HttpRequest m_CurrentRequest;
@ -67,7 +69,9 @@ private:
static void TimeoutTimerHandler(void);
void CheckLiveness(void);
void ProcessMessageAsync(HttpRequest& request);
bool ManageHeaders(HttpResponse& response);
void ProcessMessageAsync(HttpRequest& request, HttpResponse& response, const ApiUser::Ptr&);
};
}