diff --git a/config/config.go b/config/config.go index a8f94c4e..a684443e 100644 --- a/config/config.go +++ b/config/config.go @@ -263,6 +263,11 @@ type Server struct { // Optional Proxy configuration Proxy Proxy `yaml:"proxy,omitempty"` + // Graceful shutdown timeout + // Maximum time to wait for active connections to complete during shutdown. + // Default is 25s + GracefulShutdownTimeout Duration `yaml:"graceful_shutdown_timeout,omitempty"` + // Catches all undefined fields XXX map[string]interface{} `yaml:",inline"` } diff --git a/docs/src/content/docs/cn/index.md b/docs/src/content/docs/cn/index.md index 525d7ca6..a2dac85c 100644 --- a/docs/src/content/docs/cn/index.md +++ b/docs/src/content/docs/cn/index.md @@ -39,6 +39,8 @@ server: http: listen_addr: ":9090" allowed_networks: ["127.0.0.0/24"] + # 优雅关闭时等待活动连接完成的最长时间。 + graceful_shutdown_timeout: 25s users: - name: "default" diff --git a/docs/src/content/docs/configuration/default.md b/docs/src/content/docs/configuration/default.md index 1076d72b..730111ad 100644 --- a/docs/src/content/docs/configuration/default.md +++ b/docs/src/content/docs/configuration/default.md @@ -190,6 +190,10 @@ server: enable: true header: CF-Connecting-IP + # Maximum time to wait for active connections to complete during shutdown. + # Default is 25s. + graceful_shutdown_timeout: 25s + # Configs for input users. users: # Name and password are used to authorize access via BasicAuth or diff --git a/docs/src/content/docs/index.md b/docs/src/content/docs/index.md index 446ac69e..cf3b4e31 100644 --- a/docs/src/content/docs/index.md +++ b/docs/src/content/docs/index.md @@ -32,6 +32,8 @@ server: http: listen_addr: ":9090" allowed_networks: ["127.0.0.0/24"] + # Maximum time to wait for active connections to complete during shutdown. + graceful_shutdown_timeout: 25s users: - name: "default" diff --git a/main.go b/main.go index 84f594dd..ce0fd1ff 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "context" "crypto/tls" + "errors" "flag" "fmt" "net" @@ -10,6 +11,7 @@ import ( "os" "os/signal" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -38,6 +40,12 @@ var ( allowedNetworksMetrics atomic.Value proxyHandler atomic.Value allowPing atomic.Bool + + // gracefulShutdownTimeout stores the configured shutdown timeout + gracefulShutdownTimeout time.Duration + + // activeConnections tracks the number of currently active HTTP connections + activeConnections atomic.Int64 ) func main() { @@ -75,16 +83,35 @@ func main() { autocertManager = newAutocertManager(server.HTTPS.Autocert) } + gracefulShutdownTimeout = getGracefulShutdownTimeout(server.GracefulShutdownTimeout) + log.Infof("Graceful shutdown timeout: %s", gracefulShutdownTimeout) + notifyReady() + var httpServer, httpsServer *http.Server + serverErrors := make(chan error, 2) + if len(server.HTTPS.ListenAddr) != 0 { - go serveTLS(server.HTTPS) + httpsServer = startTLSServer(server.HTTPS, serverErrors) } if len(server.HTTP.ListenAddr) != 0 { - go serve(server.HTTP) + httpServer = startHTTPServer(server.HTTP, serverErrors) } - select {} + if err := waitForShutdownSignal(httpServer, httpsServer, serverErrors); err != nil { + log.Errorf("Shutdown error: %s", err) + os.Exit(1) + } +} + +// getGracefulShutdownTimeout returns the graceful shutdown timeout from config. +func getGracefulShutdownTimeout(configTimeout config.Duration) time.Duration { + const defaultTimeout = 25 * time.Second + + if configTimeout > 0 { + return time.Duration(configTimeout) + } + return defaultTimeout } func notifyReady() { @@ -113,6 +140,109 @@ func setupReloadConfigWatch() { }() } +// waitForShutdownSignal waits for SIGTERM or SIGINT and performs graceful shutdown +func waitForShutdownSignal(httpServer, httpsServer *http.Server, serverErrors <-chan error) error { + sigCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + defer stop() + + select { + case err := <-serverErrors: + return fmt.Errorf("server error: %w", err) + case <-sigCtx.Done(): + log.Infof("Shutdown signal received") + return gracefulShutdown(httpServer, httpsServer, serverErrors) + } + + return nil +} + +// gracefulShutdown performs graceful shutdown of HTTP servers +func gracefulShutdown(httpServer, httpsServer *http.Server, serverErrors <-chan error) error { + initialConns := activeConnections.Load() + log.Infof("Starting graceful shutdown with %d open connections", initialConns) + + // Create shutdown deadline + ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout) + defer cancel() + + go func() { + for { + select { + case err, ok := <-serverErrors: + if !ok { + return + } + if err != nil { + log.Errorf("Server error during shutdown: %s", err) + } + case <-ctx.Done(): + return + } + } + }() + + // Shutdown servers concurrently + var wg sync.WaitGroup + var httpErr, httpsErr error + + shutdownServer := func(s *http.Server, label string, errp *error) { + if s == nil { + return + } + wg.Add(1) + go func() { + defer wg.Done() + log.Infof("Shutting down %s server...", label) + if err := s.Shutdown(ctx); err != nil { + *errp = fmt.Errorf("%s shutdown: %w", label, err) + } else { + log.Infof("%s server stopped", label) + } + }() + } + shutdownServer(httpServer, "HTTP", &httpErr) + shutdownServer(httpsServer, "HTTPS", &httpsErr) + + // Signal channel for shutdown completion + shutdownComplete := make(chan struct{}) + go func() { + wg.Wait() + + // Clean up proxy resources + if proxy != nil { + log.Infof("Closing proxy resources...") + if err := proxy.close(); err != nil { + log.Errorf("Proxy close error: %s", err) + } + } + close(shutdownComplete) + }() + + // Wait for shutdown to complete or timeout + select { + case <-shutdownComplete: + finalConns := activeConnections.Load() + joinedErrs := errors.Join(httpErr, httpsErr) + + if joinedErrs != nil { + return fmt.Errorf("shutdown completed with errors (remaining open connections: %d): %w", finalConns, joinedErrs) + } + if finalConns > 0 { + log.Errorf("Graceful shutdown completed with %d open connections still active", finalConns) + } else if initialConns > 0 { + log.Infof("Graceful shutdown completed successfully (all connections closed)") + } else { + log.Infof("Graceful shutdown completed successfully") + } + return nil + case <-ctx.Done(): + remainingConns := activeConnections.Load() + return fmt.Errorf("shutdown timeout exceeded with %d open connections still active", remainingConns) + } + + return nil +} + var autocertManager *autocert.Manager func newAutocertManager(cfg config.Autocert) *autocert.Manager { @@ -154,7 +284,8 @@ func newListener(listenAddr string) net.Listener { return ln } -func serveTLS(cfg config.HTTPS) { +// startTLSServer starts the HTTPS server and returns the server instance for graceful shutdown +func startTLSServer(cfg config.HTTPS, serverErrors chan<- error) *http.Server { ln := newListener(cfg.ListenAddr) h := http.HandlerFunc(serveHTTP) @@ -164,13 +295,21 @@ func serveTLS(cfg config.HTTPS) { log.Fatalf("cannot build TLS config: %s", err) } tln := tls.NewListener(ln, tlsCfg) + + server := newServer(tln, h, cfg.TimeoutCfg) log.Infof("Serving https on %q", cfg.ListenAddr) - if err := listenAndServe(tln, h, cfg.TimeoutCfg); err != nil { - log.Fatalf("TLS server error on %q: %s", cfg.ListenAddr, err) - } + + go func() { + if err := server.Serve(tln); err != nil && err != http.ErrServerClosed { + serverErrors <- fmt.Errorf("TLS server error on %q: %w", cfg.ListenAddr, err) + } + }() + + return server } -func serve(cfg config.HTTP) { +// startHTTPServer starts the HTTP server and returns the server instance for graceful shutdown +func startHTTPServer(cfg config.HTTP, serverErrors chan<- error) *http.Server { var h http.Handler ln := newListener(cfg.ListenAddr) @@ -187,10 +326,17 @@ func serve(cfg config.HTTP) { } h = autocertManager.HTTPHandler(h) } + + server := newServer(ln, h, cfg.TimeoutCfg) log.Infof("Serving http on %q", cfg.ListenAddr) - if err := listenAndServe(ln, h, cfg.TimeoutCfg); err != nil { - log.Fatalf("HTTP server error on %q: %s", cfg.ListenAddr, err) - } + + go func() { + if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { + serverErrors <- fmt.Errorf("HTTP server error on %q: %w", cfg.ListenAddr, err) + } + }() + + return server } func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Server { @@ -201,17 +347,22 @@ func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Ser ReadTimeout: time.Duration(cfg.ReadTimeout), WriteTimeout: time.Duration(cfg.WriteTimeout), IdleTimeout: time.Duration(cfg.IdleTimeout), + ConnState: func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateNew: + activeConnections.Add(1) + log.Debugf("Connection opened from %s (active: %d)", conn.RemoteAddr(), activeConnections.Load()) + case http.StateClosed, http.StateHijacked: + activeConnections.Add(-1) + log.Debugf("Connection closed from %s (active: %d)", conn.RemoteAddr(), activeConnections.Load()) + } + }, // Suppress error logging from the server, since chproxy // must handle all these errors in the code. ErrorLog: log.NilLogger, } } -func listenAndServe(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) error { - s := newServer(ln, h, cfg) - return s.Serve(ln) -} - var promHandler = promhttp.Handler() //nolint:cyclop //TODO reduce complexity here. diff --git a/main_test.go b/main_test.go index fb0b414d..875bda3e 100644 --- a/main_test.go +++ b/main_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "net/url" "os" + "os/exec" "regexp" "strconv" "strings" @@ -1084,6 +1085,156 @@ func startHTTP() (*http.Server, chan struct{}) { return s, done } +func TestGracefulShutdownWaitsForRequest(t *testing.T) { + if os.Getenv("CHPROXY_TEST_GRACEFUL_SHUTDOWN") == "1" { + activeConnections.Store(0) + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + fmt.Fprintf(os.Stderr, "cannot listen: %s\n", err) + os.Exit(2) + } + + started := make(chan struct{}) + unblock := make(chan struct{}) + s := newServer(ln, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + close(started) + <-unblock + w.WriteHeader(http.StatusOK) + }), config.TimeoutCfg{}) + go s.Serve(ln) + + reqDone := make(chan struct{}) + go func() { + resp, err := http.Get("http://" + ln.Addr().String()) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(2) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + fmt.Fprintf(os.Stderr, "unexpected response status: %d\n", resp.StatusCode) + os.Exit(2) + } + close(reqDone) + }() + + select { + case <-started: + case <-time.After(500 * time.Millisecond): + fmt.Fprintln(os.Stderr, "handler did not start in time") + os.Exit(2) + } + + if activeConnections.Load() == 0 { + fmt.Fprintln(os.Stderr, "active connections not tracked") + os.Exit(2) + } + + shutdownStart := time.Now() + go func() { + time.Sleep(200 * time.Millisecond) + close(unblock) + }() + + gracefulShutdownTimeout = 2 * time.Second + if err := gracefulShutdown(s, nil, nil); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(2) + } + + shutdownDuration := time.Since(shutdownStart) + if shutdownDuration < 200*time.Millisecond || shutdownDuration > 3*time.Second { + fmt.Fprintf(os.Stderr, "shutdown took %v, expected ~200ms\n", shutdownDuration) + os.Exit(2) + } + + select { + case <-reqDone: + case <-time.After(1 * time.Second): + fmt.Fprintln(os.Stderr, "request did not complete") + os.Exit(2) + } + + if activeConnections.Load() != 0 { + fmt.Fprintf(os.Stderr, "active connections not cleaned up: %d\n", activeConnections.Load()) + os.Exit(2) + } + + return + } + + cmd := exec.Command(os.Args[0], "-test.run", "TestGracefulShutdownWaitsForRequest") + cmd.Env = append(os.Environ(), "CHPROXY_TEST_GRACEFUL_SHUTDOWN=1") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("graceful shutdown child failed: %v\n%s", err, output) + } +} + +func TestGracefulShutdownTimesOut(t *testing.T) { + if os.Getenv("CHPROXY_TEST_GRACEFUL_TIMEOUT") == "1" { + activeConnections.Store(0) + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + fmt.Fprintf(os.Stderr, "cannot listen: %s\n", err) + os.Exit(2) + } + + started := make(chan struct{}) + s := newServer(ln, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + close(started) + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + }), config.TimeoutCfg{}) + go s.Serve(ln) + + go func() { + resp, err := http.Get("http://" + ln.Addr().String()) + if err != nil { + return + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + select { + case <-started: + case <-time.After(500 * time.Millisecond): + fmt.Fprintln(os.Stderr, "handler did not start in time") + os.Exit(2) + } + + shutdownStart := time.Now() + gracefulShutdownTimeout = 500 * time.Millisecond + err = gracefulShutdown(s, nil, nil) + + shutdownDuration := time.Since(shutdownStart) + if shutdownDuration < 500*time.Millisecond || shutdownDuration > 3*time.Second { + fmt.Fprintf(os.Stderr, "shutdown took %v, expected ~500ms\n", shutdownDuration) + os.Exit(2) + } + + if err == nil { + fmt.Fprintln(os.Stderr, "expected timeout error, got nil") + os.Exit(2) + } + if !strings.Contains(err.Error(), "timeout") { + fmt.Fprintf(os.Stderr, "expected timeout error, got: %v\n", err) + os.Exit(2) + } + + return + } + + cmd := exec.Command(os.Args[0], "-test.run", "TestGracefulShutdownTimesOut") + cmd.Env = append(os.Environ(), "CHPROXY_TEST_GRACEFUL_TIMEOUT=1") + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("graceful shutdown timeout child failed: %v\n%s", err, output) + } +} + // TODO randomise port for each instance of the mock func startRedis(t *testing.T) *miniredis.Miniredis { redis := miniredis.NewMiniRedis() diff --git a/proxy.go b/proxy.go index 72ec90bb..151fa0de 100644 --- a/proxy.go +++ b/proxy.go @@ -710,6 +710,43 @@ func (rp *reverseProxy) applyConfig(cfg *config.Config) error { return nil } +// close performs cleanup of proxy resources during graceful shutdown +func (rp *reverseProxy) close() error { + rp.configLock.Lock() + defer rp.configLock.Unlock() + + rp.lock.RLock() + caches := rp.caches + rp.lock.RUnlock() + + // Close all caches + for name, c := range caches { + if err := c.Close(); err != nil { + log.Errorf("error closing cache %q: %s", name, err) + } + } + + // Signal heartbeat goroutines to stop + if rp.reloadSignal != nil { + select { + case <-rp.reloadSignal: + // Channel already closed, nothing to do + default: + close(rp.reloadSignal) + } + } + + // Wait for heartbeat goroutines to exit + rp.reloadWG.Wait() + + // Close idle connections in the transport + if transport, ok := rp.rp.Transport.(*http.Transport); ok { + transport.CloseIdleConnections() + } + + return nil +} + func initTempCaches(caches map[string]*cache.AsyncCache, transactionsTimeout config.Duration, cfg []config.Cache) error { for _, cc := range cfg { if _, ok := caches[cc.Name]; ok {