diff --git a/lib/remote/httprequest.cpp b/lib/remote/httprequest.cpp index 11480d86a..0d6c4a618 100644 --- a/lib/remote/httprequest.cpp +++ b/lib/remote/httprequest.cpp @@ -25,134 +25,147 @@ using namespace icinga; HttpRequest::HttpRequest(Stream::Ptr stream) - : Complete(false), + : CompleteHeaders(false), + CompleteBody(false), ProtocolVersion(HttpVersion11), Headers(new Dictionary()), m_Stream(std::move(stream)), m_State(HttpRequestStart) { } -bool HttpRequest::Parse(StreamReadContext& src, bool may_wait) +bool HttpRequest::ParseHeader(StreamReadContext& src, bool may_wait) { if (!m_Stream) return false; - if (m_State != HttpRequestBody) { - String line; - StreamReadStatus srs = m_Stream->ReadLine(&line, src, may_wait); + if (m_State != HttpRequestStart && m_State != HttpRequestHeaders) + return false; - if (srs != StatusNewItem) { - if (src.Size > 512) - BOOST_THROW_EXCEPTION(std::invalid_argument("Line length for HTTP header exceeded")); + String line; + StreamReadStatus srs = m_Stream->ReadLine(&line, src, may_wait); + if (srs != StatusNewItem) { + if (src.Size > 512) + BOOST_THROW_EXCEPTION(std::invalid_argument("Line length for HTTP header exceeded")); + + return false; + } + + if (line.GetLength() > 512) + BOOST_THROW_EXCEPTION(std::invalid_argument("Line length for HTTP header exceeded")); + + if (m_State == HttpRequestStart) { + /* ignore trailing new-lines */ + if (line == "") + return true; + + std::vector<String> tokens = line.Split(" "); + Log(LogDebug, "HttpRequest") + << "line: " << line << ", tokens: " << tokens.size(); + if (tokens.size() != 3) + BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid HTTP request")); + + RequestMethod = tokens[0]; + RequestUrl = new class Url(tokens[1]); + + if (tokens[2] == "HTTP/1.0") + ProtocolVersion = HttpVersion10; + else if (tokens[2] == "HTTP/1.1") { + ProtocolVersion = HttpVersion11; + } else + BOOST_THROW_EXCEPTION(std::invalid_argument("Unsupported HTTP version")); + + m_State = HttpRequestHeaders; + return true; + } else { // m_State = HttpRequestHeaders + if (line == "") { + m_State = HttpRequestBody; + CompleteHeaders = true; + return true; + + } else { + if (Headers->GetLength() > 128) + BOOST_THROW_EXCEPTION(std::invalid_argument("Maximum number of HTTP request headers exceeded")); + + String::SizeType pos = line.FindFirstOf(":"); + if (pos == String::NPos) + BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid HTTP request")); + + String key = line.SubStr(0, pos).ToLower().Trim(); + String value = line.SubStr(pos + 1).Trim(); + Headers->Set(key, value); + + if (key == "x-http-method-override") + RequestMethod = value; + + return true; + } + } +} + +bool HttpRequest::ParseBody(StreamReadContext& src, bool may_wait) +{ + if (!m_Stream || m_State != HttpRequestBody) + return false; + + /* we're done if the request doesn't contain a message body */ + if (!Headers->Contains("content-length") && !Headers->Contains("transfer-encoding")) { + CompleteBody = true; + return true; + } else if (!m_Body) + m_Body = new FIFO(); + + if (CompleteBody) + return true; + + if (Headers->Get("transfer-encoding") == "chunked") { + if (!m_ChunkContext) + m_ChunkContext = std::make_shared<ChunkReadContext>(std::ref(src)); + + char *data; + size_t size; + StreamReadStatus srs = HttpChunkedEncoding::ReadChunkFromStream(m_Stream, &data, &size, *m_ChunkContext.get(), may_wait); + + if (srs != StatusNewItem) + return false; + + m_Body->Write(data, size); + + delete [] data; + + if (size == 0) { + CompleteBody = true; + return true; + } + } else { + if (src.Eof) + BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); + + if (src.MustRead) { + if (!src.FillFromStream(m_Stream, false)) { + src.Eof = true; + BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); + } + + src.MustRead = false; + } + + long length_indicator_signed = Convert::ToLong(Headers->Get("content-length")); + + if (length_indicator_signed < 0) + BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative.")); + + size_t length_indicator = length_indicator_signed; + + if (src.Size < length_indicator) { + src.MustRead = true; return false; } - if (line.GetLength() > 512) - BOOST_THROW_EXCEPTION(std::invalid_argument("Line length for HTTP header exceeded")); - - if (m_State == HttpRequestStart) { - /* ignore trailing new-lines */ - if (line == "") - return true; - - std::vector<String> tokens = line.Split(" "); - Log(LogDebug, "HttpRequest") - << "line: " << line << ", tokens: " << tokens.size(); - if (tokens.size() != 3) - BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid HTTP request")); - - RequestMethod = tokens[0]; - RequestUrl = new class Url(tokens[1]); - - if (tokens[2] == "HTTP/1.0") - ProtocolVersion = HttpVersion10; - else if (tokens[2] == "HTTP/1.1") { - ProtocolVersion = HttpVersion11; - } else - BOOST_THROW_EXCEPTION(std::invalid_argument("Unsupported HTTP version")); - - m_State = HttpRequestHeaders; - } else if (m_State == HttpRequestHeaders) { - if (line == "") { - m_State = HttpRequestBody; - - /* we're done if the request doesn't contain a message body */ - if (!Headers->Contains("content-length") && !Headers->Contains("transfer-encoding")) - Complete = true; - else - m_Body = new FIFO(); - - return true; - - } else { - if (Headers->GetLength() > 128) - BOOST_THROW_EXCEPTION(std::invalid_argument("Maximum number of HTTP request headers exceeded")); - - String::SizeType pos = line.FindFirstOf(":"); - if (pos == String::NPos) - BOOST_THROW_EXCEPTION(std::invalid_argument("Invalid HTTP request")); - - String key = line.SubStr(0, pos).ToLower().Trim(); - String value = line.SubStr(pos + 1).Trim(); - Headers->Set(key, value); - - if (key == "x-http-method-override") - RequestMethod = value; - } - } else { - VERIFY(!"Invalid HTTP request state."); - } - } else if (m_State == HttpRequestBody) { - if (Headers->Get("transfer-encoding") == "chunked") { - if (!m_ChunkContext) - m_ChunkContext = std::make_shared<ChunkReadContext>(std::ref(src)); - - char *data; - size_t size; - StreamReadStatus srs = HttpChunkedEncoding::ReadChunkFromStream(m_Stream, &data, &size, *m_ChunkContext.get(), may_wait); - - if (srs != StatusNewItem) - return false; - - m_Body->Write(data, size); - - delete [] data; - - if (size == 0) { - Complete = true; - return true; - } - } else { - if (src.Eof) - BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); - - if (src.MustRead) { - if (!src.FillFromStream(m_Stream, false)) { - src.Eof = true; - BOOST_THROW_EXCEPTION(std::invalid_argument("Unexpected EOF in HTTP body")); - } - - src.MustRead = false; - } - - long length_indicator_signed = Convert::ToLong(Headers->Get("content-length")); - - if (length_indicator_signed < 0) - BOOST_THROW_EXCEPTION(std::invalid_argument("Content-Length must not be negative.")); - - size_t length_indicator = length_indicator_signed; - - if (src.Size < length_indicator) { - src.MustRead = true; - return false; - } - - m_Body->Write(src.Buffer, length_indicator); - src.DropData(length_indicator); - Complete = true; - return true; - } + m_Body->Write(src.Buffer, length_indicator); + src.DropData(length_indicator); + CompleteBody = true; + return true; } return true; diff --git a/lib/remote/httprequest.hpp b/lib/remote/httprequest.hpp index b59917283..e8591474c 100644 --- a/lib/remote/httprequest.hpp +++ b/lib/remote/httprequest.hpp @@ -52,7 +52,8 @@ enum HttpRequestState struct HttpRequest { public: - bool Complete; + bool CompleteHeaders; + bool CompleteBody; String RequestMethod; Url::Ptr RequestUrl; @@ -62,7 +63,8 @@ public: HttpRequest(Stream::Ptr stream); - bool Parse(StreamReadContext& src, bool may_wait); + bool ParseHeader(StreamReadContext& src, bool may_wait); + bool ParseBody(StreamReadContext& src, bool may_wait); size_t ReadBody(char *data, size_t count); void AddHeader(const String& key, const String& value); diff --git a/lib/remote/httpserverconnection.cpp b/lib/remote/httpserverconnection.cpp index 39cb397a6..9e400da16 100644 --- a/lib/remote/httpserverconnection.cpp +++ b/lib/remote/httpserverconnection.cpp @@ -92,7 +92,7 @@ bool HttpServerConnection::ProcessMessage() bool res; try { - res = m_CurrentRequest.Parse(m_Context, false); + res = m_CurrentRequest.ParseHeader(m_Context, false); } catch (const std::invalid_argument& ex) { HttpResponse response(m_Stream, m_CurrentRequest); response.SetStatus(400, "Bad request"); @@ -113,7 +113,7 @@ bool HttpServerConnection::ProcessMessage() return false; } - if (m_CurrentRequest.Complete) { + if (m_CurrentRequest.CompleteHeaders) { m_RequestQueue.Enqueue(std::bind(&HttpServerConnection::ProcessMessageAsync, HttpServerConnection::Ptr(this), m_CurrentRequest)); @@ -241,25 +241,37 @@ void HttpServerConnection::ProcessMessageAsync(HttpRequest& request) response.WriteBody(msg.CStr(), msg.GetLength()); } } else { - try { - HttpHandler::ProcessRequest(user, request, response); - } catch (const std::exception& ex) { - Log(LogCritical, "HttpServerConnection") - << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex); - response.SetStatus(503, "Unhandled exception"); + bool res = true; + while (!request.CompleteBody) + res = request.ParseBody(m_Context, false); + if (!res) { + Log(LogCritical, "HttpServerConnection", "Failed to read body"); + Dictionary::Ptr result = new Dictionary({ + { "error", 400 }, + { "status", "Bad Request: Malformed body." } + }); + HttpUtility::SendJsonBody(response, nullptr, result); + } else { + try { + HttpHandler::ProcessRequest(user, request, response); + } catch (const std::exception& ex) { + Log(LogCritical, "HttpServerConnection") + << "Unhandled exception while processing Http request: " << DiagnosticInformation(ex); + response.SetStatus(503, "Unhandled exception"); - String errorInfo = DiagnosticInformation(ex); + String errorInfo = DiagnosticInformation(ex); - if (request.Headers->Get("accept") == "application/json") { - Dictionary::Ptr result = new Dictionary({ - { "error", 503 }, - { "status", errorInfo } - }); + if (request.Headers->Get("accept") == "application/json") { + Dictionary::Ptr result = new Dictionary({ + { "error", 503 }, + { "status", errorInfo } + }); - HttpUtility::SendJsonBody(response, nullptr, result); - } else { - response.AddHeader("Content-Type", "text/plain"); - response.WriteBody(errorInfo.CStr(), errorInfo.GetLength()); + HttpUtility::SendJsonBody(response, nullptr, result); + } else { + response.AddHeader("Content-Type", "text/plain"); + response.WriteBody(errorInfo.CStr(), errorInfo.GetLength()); + } } } }