Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ type Server struct {
// Optional Proxy configuration
Proxy Proxy `yaml:"proxy,omitempty"`

// Graceful shutdown timeout
// Maximum time to wait for active connections to complete during shutdown.
// Default is 25s
GracefulShutdownTimeout Duration `yaml:"graceful_shutdown_timeout,omitempty"`

// Catches all undefined fields
XXX map[string]interface{} `yaml:",inline"`
}
Expand Down
2 changes: 2 additions & 0 deletions docs/src/content/docs/cn/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ server:
http:
listen_addr: ":9090"
allowed_networks: ["127.0.0.0/24"]
# 优雅关闭时等待活动连接完成的最长时间。
graceful_shutdown_timeout: 25s

users:
- name: "default"
Expand Down
4 changes: 4 additions & 0 deletions docs/src/content/docs/configuration/default.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ server:
enable: true
header: CF-Connecting-IP

# Maximum time to wait for active connections to complete during shutdown.
# Default is 25s.
graceful_shutdown_timeout: 25s

# Configs for input users.
users:
# Name and password are used to authorize access via BasicAuth or
Expand Down
2 changes: 2 additions & 0 deletions docs/src/content/docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ server:
http:
listen_addr: ":9090"
allowed_networks: ["127.0.0.0/24"]
# Maximum time to wait for active connections to complete during shutdown.
graceful_shutdown_timeout: 25s

users:
- name: "default"
Expand Down
183 changes: 167 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package main
import (
"context"
"crypto/tls"
"errors"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
Expand Down Expand Up @@ -38,6 +40,12 @@ var (
allowedNetworksMetrics atomic.Value
proxyHandler atomic.Value
allowPing atomic.Bool

// gracefulShutdownTimeout stores the configured shutdown timeout
gracefulShutdownTimeout time.Duration

// activeConnections tracks the number of currently active HTTP connections
activeConnections atomic.Int64
)

func main() {
Expand Down Expand Up @@ -75,16 +83,35 @@ func main() {
autocertManager = newAutocertManager(server.HTTPS.Autocert)
}

gracefulShutdownTimeout = getGracefulShutdownTimeout(server.GracefulShutdownTimeout)
log.Infof("Graceful shutdown timeout: %s", gracefulShutdownTimeout)

notifyReady()

var httpServer, httpsServer *http.Server
serverErrors := make(chan error, 2)

if len(server.HTTPS.ListenAddr) != 0 {
go serveTLS(server.HTTPS)
httpsServer = startTLSServer(server.HTTPS, serverErrors)
}
if len(server.HTTP.ListenAddr) != 0 {
go serve(server.HTTP)
httpServer = startHTTPServer(server.HTTP, serverErrors)
}

select {}
if err := waitForShutdownSignal(httpServer, httpsServer, serverErrors); err != nil {
log.Errorf("Shutdown error: %s", err)
os.Exit(1)
}
}

// getGracefulShutdownTimeout returns the graceful shutdown timeout from config.
func getGracefulShutdownTimeout(configTimeout config.Duration) time.Duration {
const defaultTimeout = 25 * time.Second

if configTimeout > 0 {
return time.Duration(configTimeout)
}
return defaultTimeout
}

func notifyReady() {
Expand Down Expand Up @@ -113,6 +140,109 @@ func setupReloadConfigWatch() {
}()
}

// waitForShutdownSignal waits for SIGTERM or SIGINT and performs graceful shutdown
func waitForShutdownSignal(httpServer, httpsServer *http.Server, serverErrors <-chan error) error {
sigCtx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
defer stop()

select {
case err := <-serverErrors:
return fmt.Errorf("server error: %w", err)
case <-sigCtx.Done():
log.Infof("Shutdown signal received")
return gracefulShutdown(httpServer, httpsServer, serverErrors)
}

return nil
}

// gracefulShutdown performs graceful shutdown of HTTP servers
func gracefulShutdown(httpServer, httpsServer *http.Server, serverErrors <-chan error) error {
initialConns := activeConnections.Load()
log.Infof("Starting graceful shutdown with %d open connections", initialConns)

// Create shutdown deadline
ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
defer cancel()

go func() {
for {
select {
case err, ok := <-serverErrors:
if !ok {
return
}
if err != nil {
log.Errorf("Server error during shutdown: %s", err)
}
case <-ctx.Done():
return
}
}
}()

// Shutdown servers concurrently
var wg sync.WaitGroup
var httpErr, httpsErr error

shutdownServer := func(s *http.Server, label string, errp *error) {
if s == nil {
return
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("Shutting down %s server...", label)
if err := s.Shutdown(ctx); err != nil {
*errp = fmt.Errorf("%s shutdown: %w", label, err)
} else {
log.Infof("%s server stopped", label)
}
}()
}
shutdownServer(httpServer, "HTTP", &httpErr)
shutdownServer(httpsServer, "HTTPS", &httpsErr)

// Signal channel for shutdown completion
shutdownComplete := make(chan struct{})
go func() {
wg.Wait()

// Clean up proxy resources
if proxy != nil {
log.Infof("Closing proxy resources...")
if err := proxy.close(); err != nil {
log.Errorf("Proxy close error: %s", err)
}
}
close(shutdownComplete)
}()

// Wait for shutdown to complete or timeout
select {
case <-shutdownComplete:
finalConns := activeConnections.Load()
joinedErrs := errors.Join(httpErr, httpsErr)

if joinedErrs != nil {
return fmt.Errorf("shutdown completed with errors (remaining open connections: %d): %w", finalConns, joinedErrs)
}
if finalConns > 0 {
log.Errorf("Graceful shutdown completed with %d open connections still active", finalConns)
} else if initialConns > 0 {
log.Infof("Graceful shutdown completed successfully (all connections closed)")
} else {
log.Infof("Graceful shutdown completed successfully")
}
return nil
case <-ctx.Done():
remainingConns := activeConnections.Load()
return fmt.Errorf("shutdown timeout exceeded with %d open connections still active", remainingConns)
}

return nil
}

var autocertManager *autocert.Manager

func newAutocertManager(cfg config.Autocert) *autocert.Manager {
Expand Down Expand Up @@ -154,7 +284,8 @@ func newListener(listenAddr string) net.Listener {
return ln
}

func serveTLS(cfg config.HTTPS) {
// startTLSServer starts the HTTPS server and returns the server instance for graceful shutdown
func startTLSServer(cfg config.HTTPS, serverErrors chan<- error) *http.Server {
ln := newListener(cfg.ListenAddr)

h := http.HandlerFunc(serveHTTP)
Expand All @@ -164,13 +295,21 @@ func serveTLS(cfg config.HTTPS) {
log.Fatalf("cannot build TLS config: %s", err)
}
tln := tls.NewListener(ln, tlsCfg)

server := newServer(tln, h, cfg.TimeoutCfg)
log.Infof("Serving https on %q", cfg.ListenAddr)
if err := listenAndServe(tln, h, cfg.TimeoutCfg); err != nil {
log.Fatalf("TLS server error on %q: %s", cfg.ListenAddr, err)
}

go func() {
if err := server.Serve(tln); err != nil && err != http.ErrServerClosed {
serverErrors <- fmt.Errorf("TLS server error on %q: %w", cfg.ListenAddr, err)
}
}()

return server
}

func serve(cfg config.HTTP) {
// startHTTPServer starts the HTTP server and returns the server instance for graceful shutdown
func startHTTPServer(cfg config.HTTP, serverErrors chan<- error) *http.Server {
var h http.Handler
ln := newListener(cfg.ListenAddr)

Expand All @@ -187,10 +326,17 @@ func serve(cfg config.HTTP) {
}
h = autocertManager.HTTPHandler(h)
}

server := newServer(ln, h, cfg.TimeoutCfg)
log.Infof("Serving http on %q", cfg.ListenAddr)
if err := listenAndServe(ln, h, cfg.TimeoutCfg); err != nil {
log.Fatalf("HTTP server error on %q: %s", cfg.ListenAddr, err)
}

go func() {
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
serverErrors <- fmt.Errorf("HTTP server error on %q: %w", cfg.ListenAddr, err)
}
}()

return server
}

func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Server {
Expand All @@ -201,17 +347,22 @@ func newServer(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) *http.Ser
ReadTimeout: time.Duration(cfg.ReadTimeout),
WriteTimeout: time.Duration(cfg.WriteTimeout),
IdleTimeout: time.Duration(cfg.IdleTimeout),
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
activeConnections.Add(1)
log.Debugf("Connection opened from %s (active: %d)", conn.RemoteAddr(), activeConnections.Load())
case http.StateClosed, http.StateHijacked:
activeConnections.Add(-1)
log.Debugf("Connection closed from %s (active: %d)", conn.RemoteAddr(), activeConnections.Load())
}
},
// Suppress error logging from the server, since chproxy
// must handle all these errors in the code.
ErrorLog: log.NilLogger,
}
}

func listenAndServe(ln net.Listener, h http.Handler, cfg config.TimeoutCfg) error {
s := newServer(ln, h, cfg)
return s.Serve(ln)
}

var promHandler = promhttp.Handler()

//nolint:cyclop //TODO reduce complexity here.
Expand Down
Loading