This PR fixes a panic in a WaitGroup that occurs when Gitea is shut down with Ctrl+C. It ensures that all active connection pointers in the server are properly tracked and forcibly closed when the hammer shutdown is invoked. The process remains graceful: the normal shutdown sequence runs before the hammer is triggered, and existing connections are given a timeout period to complete gracefully. This PR also fixes the `no logger writer` problem: the log close is now only invoked when the command exits.

- Fixes #35468
- Fixes #35551
- Fixes #35559
- Replace #35578

---------

Co-authored-by: wxiaoguang <wxiaoguang@gmail.com>
		
			
				
	
	
		
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

// This code is highly inspired by endless go

package graceful

import (
	"crypto/tls"
	"net"
	"os"
	"strings"
	"sync"
	"syscall"
	"time"

	"code.gitea.io/gitea/modules/log"
	"code.gitea.io/gitea/modules/proxyprotocol"
	"code.gitea.io/gitea/modules/setting"
)

// GetListener returns a net listener
// This determines the implementation of net.Listener which the server will use,
// so that downstreams could provide their own Listener, such as with a hidden service or a p2p network
var GetListener = DefaultGetListener
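
// As the comment above notes, a downstream can swap in its own listener
// implementation by reassigning GetListener. A minimal sketch (not part of the
// original file; the customListen helper is hypothetical):
//
//	func init() {
//		defaultGetListener := GetListener
//		GetListener = func(network, address string) (net.Listener, error) {
//			if network == "custom" {
//				return customListen(address) // hypothetical transport, e.g. a hidden service
//			}
//			return defaultGetListener(network, address)
//		}
//	}
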
// ServeFunction represents a listen.Accept loop
type ServeFunction = func(net.Listener) error

// Server represents our graceful server
type Server struct {
	network  string
	address  string
	listener net.Listener

	lock          sync.RWMutex
	state         state
	connCounter   int64
	connEmptyCond *sync.Cond

	BeforeBegin          func(network, address string)
	OnShutdown           func()
	PerWriteTimeout      time.Duration
	PerWritePerKbTimeout time.Duration
}

// NewServer creates a server on network at provided address
func NewServer(network, address, name string) *Server {
	if GetManager().IsChild() {
		log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
	} else {
		log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
	}
	srv := &Server{
		state:                stateInit,
		network:              network,
		address:              address,
		PerWriteTimeout:      setting.PerWriteTimeout,
		PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
	}
	srv.connEmptyCond = sync.NewCond(&srv.lock)

	srv.BeforeBegin = func(network, addr string) {
		log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
	}

	return srv
}
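
// A typical caller wires NewServer to an http.Server's Serve method. A minimal
// usage sketch (illustrative only; myHandler is hypothetical and the calling
// package would import net/http):
//
//	srv := NewServer("tcp", "127.0.0.1:3000", "web")
//	httpServer := &http.Server{Handler: myHandler}
//	err := srv.ListenAndServe(httpServer.Serve, false /* useProxyProtocol */)
//
// httpServer.Serve satisfies ServeFunction because it takes a net.Listener and
// returns an error.
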
// ListenAndServe listens on the provided network address and then calls Serve
// to handle requests on incoming connections.
func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
	go srv.awaitShutdown()

	listener, err := GetListener(srv.network, srv.address)
	if err != nil {
		log.Error("Unable to GetListener: %v", err)
		return err
	}

	// we need to wrap the listener to take account of our lifecycle
	listener = newWrappedListener(listener, srv)

	// Now we need to take account of ProxyProtocol settings...
	if useProxyProtocol {
		listener = &proxyprotocol.Listener{
			Listener:           listener,
			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
		}
	}
	srv.listener = listener

	srv.BeforeBegin(srv.network, srv.address)

	return srv.Serve(serve)
}

// ListenAndServeTLSConfig listens on the provided network address and then calls
// Serve to handle requests on incoming TLS connections.
func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
	go srv.awaitShutdown()

	if tlsConfig.MinVersion == 0 {
		tlsConfig.MinVersion = tls.VersionTLS12
	}

	listener, err := GetListener(srv.network, srv.address)
	if err != nil {
		log.Error("Unable to get Listener: %v", err)
		return err
	}

	// we need to wrap the listener to take account of our lifecycle
	listener = newWrappedListener(listener, srv)

	// Now we need to take account of ProxyProtocol settings...
	// If we're not bridging then we expect that the proxy will forward the connection to us
	if useProxyProtocol && !proxyProtocolTLSBridging {
		listener = &proxyprotocol.Listener{
			Listener:           listener,
			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
		}
	}

	// Now handle the tls protocol
	listener = tls.NewListener(listener, tlsConfig)

	// Now if we're bridging then we need the proxy to tell us who we're bridging for...
	if useProxyProtocol && proxyProtocolTLSBridging {
		listener = &proxyprotocol.Listener{
			Listener:           listener,
			ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
			AcceptUnknown:      setting.ProxyProtocolAcceptUnknown,
		}
	}

	srv.listener = listener
	srv.BeforeBegin(srv.network, srv.address)

	return srv.Serve(serve)
}

// Serve accepts incoming connections on the wrapped listener, creating a new
// service goroutine for each. The service goroutines read requests and then
// call the provided serve function to reply to them.
//
// In addition to the standard Serve behaviour, each connection is counted so
// that all outstanding connections can be served before shutting down the server.
func (srv *Server) Serve(serve ServeFunction) error {
	defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
	srv.setState(stateRunning)
	GetManager().RegisterServer()
	err := serve(srv.listener)
	log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
	srv.waitForActiveConnections()
	srv.setState(stateTerminate)
	GetManager().ServerDone()
	// use of closed means that the listeners are closed - i.e. we should be shutting down - return nil
	if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
		return nil
	}
	return err
}

func (srv *Server) getState() state {
	srv.lock.RLock()
	defer srv.lock.RUnlock()

	return srv.state
}

func (srv *Server) setState(st state) {
	srv.lock.Lock()
	defer srv.lock.Unlock()

	srv.state = st
}

// waitForActiveConnections blocks until all active connections have closed
// (connCounter drops to zero), or until closeAllConnections gives up on them.
func (srv *Server) waitForActiveConnections() {
	srv.lock.Lock()
	for srv.connCounter > 0 {
		srv.connEmptyCond.Wait()
	}
	srv.lock.Unlock()
}

// wrapConnection records an accepted connection in the connection counter so
// the server can wait for it on shutdown; connections accepted after the
// server has left the running state are closed and rejected immediately.
func (srv *Server) wrapConnection(c net.Conn) (net.Conn, error) {
	srv.lock.Lock()
	defer srv.lock.Unlock()

	if srv.state != stateRunning {
		_ = c.Close()
		return nil, syscall.EINVAL // same as AcceptTCP
	}

	srv.connCounter++
	return &wrappedConn{Conn: c, server: srv}, nil
}

// removeConnection decrements the connection counter and wakes any goroutine
// waiting in waitForActiveConnections once no connections remain.
func (srv *Server) removeConnection(_ *wrappedConn) {
	srv.lock.Lock()
	defer srv.lock.Unlock()

	srv.connCounter--
	if srv.connCounter <= 0 {
		srv.connEmptyCond.Broadcast()
	}
}

// closeAllConnections forcefully closes all active connections
func (srv *Server) closeAllConnections() {
	srv.lock.Lock()
	if srv.connCounter > 0 {
		log.Warn("After graceful shutdown period, %d connections are still active. Forcefully close.", srv.connCounter)
		srv.connCounter = 0 // OS will close all the connections after the process exits, so we just assume there is no active connection now
	}
	srv.lock.Unlock()
	srv.connEmptyCond.Broadcast()
}
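
// Per the PR description above, closeAllConnections is the "hammer": the normal
// shutdown closes the listener first, Serve then blocks in
// waitForActiveConnections, and only after the graceful timeout does the hammer
// wake it by zeroing the counter. A sketch of that sequencing (illustrative,
// not the manager's actual code; done and gracefulTimeout are hypothetical):
//
//	srv.listener.Close()                // stop accepting; Serve starts draining
//	select {
//	case <-done:                        // hypothetical channel closed when Serve returns
//	case <-time.After(gracefulTimeout): // hypothetical timeout value
//		srv.closeAllConnections()       // hammer: give up on stragglers
//	}
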
type filer interface {
	File() (*os.File, error)
}

type wrappedListener struct {
	net.Listener
	server *Server
}

var (
	_ net.Listener = (*wrappedListener)(nil)
	_ filer        = (*wrappedListener)(nil)
)

func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
	return &wrappedListener{
		Listener: l,
		server:   srv,
	}
}

func (wl *wrappedListener) Accept() (c net.Conn, err error) {
	if tcl, ok := wl.Listener.(*net.TCPListener); ok {
		// Set keepalive on TCPListeners connections if possible, see http.tcpKeepAliveListener
		tc, err := tcl.AcceptTCP()
		if err != nil {
			return nil, err
		}
		_ = tc.SetKeepAlive(true)
		_ = tc.SetKeepAlivePeriod(3 * time.Minute)
		c = tc
	} else {
		c, err = wl.Listener.Accept()
		if err != nil {
			return nil, err
		}
	}

	return wl.server.wrapConnection(c)
}

func (wl *wrappedListener) File() (*os.File, error) {
	// returns a dup(2) - FD_CLOEXEC flag *not* set so the listening socket can be passed to child processes
	return wl.Listener.(filer).File()
}
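
// File's dup(2) with FD_CLOEXEC unset is what makes graceful restarts possible:
// the listening socket can be handed to a forked child via os/exec. A rough
// sketch under that assumption (not the manager's actual restart code):
//
//	f, _ := wl.File()
//	cmd := exec.Command(os.Args[0], os.Args[1:]...)
//	cmd.ExtraFiles = []*os.File{f} // child sees the socket as fd 3
//	_ = cmd.Start()
//	// the child can rebuild the listener with net.FileListener(os.NewFile(3, ""))
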
type wrappedConn struct {
	net.Conn
	server   *Server
	deadline time.Time
}

func (w *wrappedConn) Write(p []byte) (n int, err error) {
	if w.server.PerWriteTimeout > 0 {
		minTimeout := time.Duration(len(p)/1024) * w.server.PerWritePerKbTimeout
		minDeadline := time.Now().Add(minTimeout).Add(w.server.PerWriteTimeout)

		w.deadline = w.deadline.Add(minTimeout)
		if minDeadline.After(w.deadline) {
			w.deadline = minDeadline
		}
		_ = w.Conn.SetWriteDeadline(w.deadline)
	}
	return w.Conn.Write(p)
}
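
// Worked example of the deadline arithmetic above, using illustrative values:
// with PerWriteTimeout = 30s and PerWritePerKbTimeout = 10ms, a 64 KiB write
// gets minTimeout = 64 * 10ms = 640ms, so its deadline is at least
// now + 640ms + 30s. Each write also pushes the stored deadline forward by its
// own minTimeout, so a slow but steadily progressing client is never cut off,
// while a stalled one eventually hits the write deadline.
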
func (w *wrappedConn) Close() error {
	w.server.removeConnection(w)
	return w.Conn.Close()
}