From 6073e2f1bbf83d6f92b1ff696f75d7c8a4520ac8 Mon Sep 17 00:00:00 2001
From: wxiaoguang <wxiaoguang@gmail.com>
Date: Mon, 20 Jan 2025 14:25:17 +0800
Subject: [PATCH] Refactor response writer & access logger (#33323)

And add comments & tests
---
 routers/common/middleware.go        |  12 ++--
 services/context/access_log.go      | 103 ++++++++++++++++------------
 services/context/access_log_test.go |  75 ++++++++++++++++++++
 services/context/response.go        |  37 +++++-----
 4 files changed, 159 insertions(+), 68 deletions(-)
 create mode 100644 services/context/access_log_test.go

diff --git a/routers/common/middleware.go b/routers/common/middleware.go
index 12b0c67b01..de59b78583 100644
--- a/routers/common/middleware.go
+++ b/routers/common/middleware.go
@@ -43,14 +43,18 @@ func ProtocolMiddlewares() (handlers []any) {
 
 func RequestContextHandler() func(h http.Handler) http.Handler {
 	return func(next http.Handler) http.Handler {
-		return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
-			profDesc := fmt.Sprintf("%s: %s", req.Method, req.RequestURI)
+		return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) {
+			// this response writer might not be the same as the one in context.Base.Resp
+			// because there might be a "gzip writer" in the middle, so the "written size" here is the compressed size
+			respWriter := context.WrapResponseWriter(respOrig)
+
+			profDesc := fmt.Sprintf("HTTP: %s %s", req.Method, req.RequestURI)
 			ctx, finished := reqctx.NewRequestContext(req.Context(), profDesc)
 			defer finished()
 
 			defer func() {
 				if err := recover(); err != nil {
-					RenderPanicErrorPage(resp, req, err) // it should never panic
+					RenderPanicErrorPage(respWriter, req, err) // it should never panic
 				}
 			}()
 
@@ -62,7 +66,7 @@ func RequestContextHandler() func(h http.Handler) http.Handler {
 					_ = req.MultipartForm.RemoveAll() // remove the temp files buffered to tmp directory
 				}
 			})
-			next.ServeHTTP(context.WrapResponseWriter(resp), req)
+			next.ServeHTTP(respWriter, req)
 		})
 	}
 }
diff --git a/services/context/access_log.go b/services/context/access_log.go
index 0926748ac5..1985aae118 100644
--- a/services/context/access_log.go
+++ b/services/context/access_log.go
@@ -18,13 +18,14 @@ import (
 	"code.gitea.io/gitea/modules/web/middleware"
 )
 
-type routerLoggerOptions struct {
-	req            *http.Request
+type accessLoggerTmplData struct {
 	Identity       *string
 	Start          *time.Time
-	ResponseWriter http.ResponseWriter
-	Ctx            map[string]any
-	RequestID      *string
+	ResponseWriter struct {
+		Status, Size int // plain values copied from the response writer by record
+	}
+	Ctx       map[string]any // record fills in "RemoteAddr", "RemoteHost" and "Req"
+	RequestID *string        // points at an empty string unless request ID parsing is enabled
 }
 
 const keyOfRequestIDInTemplate = ".RequestID"
@@ -51,51 +52,65 @@ func parseRequestIDFromRequestHeader(req *http.Request) string {
 	return requestID
 }
 
+type accessLogRecorder struct {
+	logger        log.BaseLogger     // destination for the rendered access log line
+	logTemplate   *template.Template // parsed form of setting.Log.AccessLogTemplate
+	needRequestID bool               // true when record should parse the request ID header
+}
+
+func (lr *accessLogRecorder) record(start time.Time, respWriter ResponseWriter, req *http.Request) {
+	var requestID string
+	if lr.needRequestID { // otherwise requestID stays empty
+		requestID = parseRequestIDFromRequestHeader(req)
+	}
+
+	reqHost, _, err := net.SplitHostPort(req.RemoteAddr)
+	if err != nil {
+		reqHost = req.RemoteAddr // could not split host from port, so log the address unchanged
+	}
+
+	identity := "-" // placeholder logged when the context data holds no signed-in user
+	data := middleware.GetContextData(req.Context())
+	if signedUser, ok := data[middleware.ContextDataKeySignedUser].(*user_model.User); ok {
+		identity = signedUser.Name
+	}
+	buf := bytes.NewBuffer([]byte{})
+	tmplData := accessLoggerTmplData{
+		Identity: &identity,
+		Start:    &start,
+		Ctx: map[string]any{
+			"RemoteAddr": req.RemoteAddr,
+			"RemoteHost": reqHost,
+			"Req":        req,
+		},
+		RequestID: &requestID,
+	}
+	tmplData.ResponseWriter.Status = respWriter.Status()
+	tmplData.ResponseWriter.Size = respWriter.WrittenSize() // exposed to templates as .ResponseWriter.Size
+	err = lr.logTemplate.Execute(buf, tmplData)
+	if err != nil {
+		log.Error("Could not execute access logger template: %v", err.Error())
+	}
+
+	lr.logger.Log(1, log.INFO, "%s", buf.String()) // runs even after a template error, with whatever buf holds
+}
+
+func newAccessLogRecorder() *accessLogRecorder {
+	return &accessLogRecorder{
+		logger:        log.GetLogger("access"),
+		logTemplate:   template.Must(template.New("log").Parse(setting.Log.AccessLogTemplate)), // Must panics if the template does not parse
+		needRequestID: len(setting.Log.RequestIDHeaders) > 0 && strings.Contains(setting.Log.AccessLogTemplate, keyOfRequestIDInTemplate),
+	}
+}
+
 // AccessLogger returns a middleware to log access logger
 func AccessLogger() func(http.Handler) http.Handler {
-	logger := log.GetLogger("access")
-	needRequestID := len(setting.Log.RequestIDHeaders) > 0 && strings.Contains(setting.Log.AccessLogTemplate, keyOfRequestIDInTemplate)
-	logTemplate, _ := template.New("log").Parse(setting.Log.AccessLogTemplate)
+	recorder := newAccessLogRecorder() // created once here and reused by every request
 	return func(next http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 			start := time.Now()
-
-			var requestID string
-			if needRequestID {
-				requestID = parseRequestIDFromRequestHeader(req)
-			}
-
-			reqHost, _, err := net.SplitHostPort(req.RemoteAddr)
-			if err != nil {
-				reqHost = req.RemoteAddr
-			}
-
 			next.ServeHTTP(w, req)
-			rw := w.(ResponseWriter)
-
-			identity := "-"
-			data := middleware.GetContextData(req.Context())
-			if signedUser, ok := data[middleware.ContextDataKeySignedUser].(*user_model.User); ok {
-				identity = signedUser.Name
-			}
-			buf := bytes.NewBuffer([]byte{})
-			err = logTemplate.Execute(buf, routerLoggerOptions{
-				req:            req,
-				Identity:       &identity,
-				Start:          &start,
-				ResponseWriter: rw,
-				Ctx: map[string]any{
-					"RemoteAddr": req.RemoteAddr,
-					"RemoteHost": reqHost,
-					"Req":        req,
-				},
-				RequestID: &requestID,
-			})
-			if err != nil {
-				log.Error("Could not execute access logger template: %v", err.Error())
-			}
-
-			logger.Info("%s", buf.String())
+			recorder.record(start, w.(ResponseWriter), req) // unchecked assertion: panics if w is not a ResponseWriter
 		})
 	}
 }
diff --git a/services/context/access_log_test.go b/services/context/access_log_test.go
new file mode 100644
index 0000000000..9aab918ae6
--- /dev/null
+++ b/services/context/access_log_test.go
@@ -0,0 +1,75 @@
+// Copyright 2025 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package context
+
+import (
+	"fmt"
+	"net/http"
+	"net/url"
+	"testing"
+	"time"
+
+	"code.gitea.io/gitea/modules/log"
+	"code.gitea.io/gitea/modules/setting"
+
+	"github.com/stretchr/testify/assert"
+)
+
+type testAccessLoggerMock struct {
+	logs []string // formatted messages, in the order Log received them
+}
+
+func (t *testAccessLoggerMock) Log(skip int, level log.Level, format string, v ...any) {
+	t.logs = append(t.logs, fmt.Sprintf(format, v...)) // skip and level are ignored
+}
+
+func (t *testAccessLoggerMock) GetLevel() log.Level {
+	return log.INFO
+}
+
+type testAccessLoggerResponseWriterMock struct{} // stub ResponseWriter reporting a fixed status and size
+
+func (t testAccessLoggerResponseWriterMock) Header() http.Header {
+	return nil
+}
+
+func (t testAccessLoggerResponseWriterMock) Before(f func(ResponseWriter)) {}
+
+func (t testAccessLoggerResponseWriterMock) WriteHeader(statusCode int) {}
+
+func (t testAccessLoggerResponseWriterMock) Write(bytes []byte) (int, error) {
+	return 0, nil
+}
+
+func (t testAccessLoggerResponseWriterMock) Flush() {}
+
+func (t testAccessLoggerResponseWriterMock) WrittenStatus() int {
+	return http.StatusOK // fixed status that TestAccessLogger expects in the log line
+}
+
+func (t testAccessLoggerResponseWriterMock) Status() int {
+	return t.WrittenStatus()
+}
+
+func (t testAccessLoggerResponseWriterMock) WrittenSize() int {
+	return 123123 // fixed size that TestAccessLogger expects in the log line
+}
+
+func TestAccessLogger(t *testing.T) {
+	setting.Log.AccessLogTemplate = `{{.Ctx.RemoteHost}} - {{.Identity}} {{.Start.Format "[02/Jan/2006:15:04:05 -0700]" }} "{{.Ctx.Req.Method}} {{.Ctx.Req.URL.RequestURI}} {{.Ctx.Req.Proto}}" {{.ResponseWriter.Status}} {{.ResponseWriter.Size}} "{{.Ctx.Req.Referer}}" "{{.Ctx.Req.UserAgent}}"`
+	recorder := newAccessLogRecorder() // parses the global template set above; NOTE(review): the global is never restored after the test
+	mockLogger := &testAccessLoggerMock{}
+	recorder.logger = mockLogger // capture the rendered line instead of sending it to the "access" logger
+	req := &http.Request{
+		RemoteAddr: "remote-addr", // no port, so RemoteHost is expected to equal this value
+		Method:     "GET",
+		Proto:      "https",
+		URL:        &url.URL{Path: "/path"},
+	}
+	req.Header = http.Header{}
+	req.Header.Add("Referer", "referer")
+	req.Header.Add("User-Agent", "user-agent")
+	recorder.record(time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), &testAccessLoggerResponseWriterMock{}, req)
+	assert.Equal(t, []string{`remote-addr - - [02/Jan/2000:03:04:05 +0000] "GET /path https" 200 123123 "referer" "user-agent"`}, mockLogger.logs)
+}
diff --git a/services/context/response.go b/services/context/response.go
index 2f271f211b..3e557a112e 100644
--- a/services/context/response.go
+++ b/services/context/response.go
@@ -15,27 +15,26 @@ type ResponseWriter interface {
 	http.Flusher
 	web_types.ResponseStatusProvider
 
-	Before(func(ResponseWriter))
-
-	Status() int // used by access logger template
-	Size() int   // used by access logger template
+	Before(fn func(ResponseWriter))
+	Status() int
+	WrittenSize() int
 }
 
-var _ ResponseWriter = &Response{}
+var _ ResponseWriter = (*Response)(nil)
 
 // Response represents a response
 type Response struct {
 	http.ResponseWriter
 	written        int
 	status         int
-	befores        []func(ResponseWriter)
+	beforeFuncs    []func(ResponseWriter) // hooks run once, ahead of the first Write or WriteHeader
 	beforeExecuted bool
 }
 
 // Write writes bytes to HTTP endpoint
 func (r *Response) Write(bs []byte) (int, error) {
 	if !r.beforeExecuted {
-		for _, before := range r.befores {
+		for _, before := range r.beforeFuncs {
 			before(r)
 		}
 		r.beforeExecuted = true
@@ -51,18 +50,14 @@ func (r *Response) Write(bs []byte) (int, error) {
 	return size, nil
 }
 
-func (r *Response) Status() int {
-	return r.status
-}
-
-func (r *Response) Size() int {
+func (r *Response) WrittenSize() int { // reports r.written; presumably the bytes accepted by Write — TODO confirm
 	return r.written
 }
 
 // WriteHeader write status code
 func (r *Response) WriteHeader(statusCode int) {
 	if !r.beforeExecuted {
-		for _, before := range r.befores {
+		for _, before := range r.beforeFuncs {
 			before(r)
 		}
 		r.beforeExecuted = true
@@ -80,6 +75,12 @@ func (r *Response) Flush() {
 	}
 }
 
+// Status returns the recorded status code; it reads the same field as WrittenStatus.
+// TODO: use WrittenStatus instead
+func (r *Response) Status() int {
+	return r.status
+}
+
 // WrittenStatus returned status code written
 func (r *Response) WrittenStatus() int {
 	return r.status
@@ -87,17 +88,13 @@ func (r *Response) WrittenStatus() int {
 
 // Before allows for a function to be called before the ResponseWriter has been written to. This is
 // useful for setting headers or any other operations that must happen before a response has been written.
-func (r *Response) Before(f func(ResponseWriter)) {
-	r.befores = append(r.befores, f)
+func (r *Response) Before(fn func(ResponseWriter)) {
+	r.beforeFuncs = append(r.beforeFuncs, fn) // hooks added after the first Write/WriteHeader never run
 }
 
 func WrapResponseWriter(resp http.ResponseWriter) *Response {
 	if v, ok := resp.(*Response); ok {
 		return v
 	}
-	return &Response{
-		ResponseWriter: resp,
-		status:         0,
-		befores:        make([]func(ResponseWriter), 0),
-	}
+	return &Response{ResponseWriter: resp} // remaining fields keep their zero values (status 0, no hooks)
 }