Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ package pool
import (
"errors"
"fmt"
"google.golang.org/grpc/connectivity"
"log"
"math"
"sync"
"sync/atomic"
"time"
)

// ErrClosed is the error resulting if the pool is closed via pool.Close().
Expand Down Expand Up @@ -103,6 +105,10 @@ func New(address string, option Options) (Pool, error) {
}
p.conns[i] = p.wrapConn(c, false)
}

// Start a health check goroutine to periodically check connections
p.startHealthCheck(time.Minute)

log.Printf("new pool success: %v\n", p.Status())

return p, nil
Expand Down Expand Up @@ -211,7 +217,7 @@ func (p *pool) Get() (Conn, error) {

// Close see Pool interface.
func (p *pool) Close() error {
atomic.StoreInt32(&p.closed, 1)
atomic.StoreInt32(&p.closed, 1) // 标记为关闭
atomic.StoreUint32(&p.index, 0)
atomic.StoreInt32(&p.current, 0)
atomic.StoreInt32(&p.ref, 0)
Expand All @@ -225,3 +231,43 @@ func (p *pool) Status() string {
return fmt.Sprintf("address:%s, index:%d, current:%d, ref:%d. option:%v",
p.address, p.index, p.current, p.ref, p.opt)
}

// Periodically check the connection pool for stale or unavailable connections
func (p *pool) startHealthCheck(interval time.Duration) {
go func() {
for {
if atomic.LoadInt32(&p.closed) == 1 {
return // Stop the goroutine when the pool is closed
}
select {
case <-time.After(interval):
p.checkAndReplaceStaleConnections()
}
}
}()
}

func (p *pool) checkAndReplaceStaleConnections() {
for i, c := range p.conns {
if c == nil {
continue
}

// Use a lock for each individual connection instead of locking the entire pool
if c.Value() == nil || c.Value().GetState() == connectivity.Shutdown {
log.Printf("Detected stale connection at index %d, replacing...", i)

p.Lock() // Lock only for the current connection
_ = c.reset()

newConn, err := p.opt.Dial(p.address)
if err != nil {
log.Printf("Failed to create new connection: %v", err)
p.Unlock()
continue
}
p.conns[i] = p.wrapConn(newConn, false)
p.Unlock()
}
}
}