From 89519abbfe1834bd4e813f3831bb2ed6fe91de86 Mon Sep 17 00:00:00 2001 From: Andrei Gavrilov Date: Wed, 10 Dec 2025 11:55:52 +0300 Subject: [PATCH 1/3] add timeout for streamer connect (for working multiadress setup) --- pkg/streamer/streamer.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/streamer/streamer.go b/pkg/streamer/streamer.go index 145820e..533088a 100644 --- a/pkg/streamer/streamer.go +++ b/pkg/streamer/streamer.go @@ -139,6 +139,8 @@ func NewReadResImpl(before, after []byte, matchedGroups map[string][]byte, match // TCPDialCtx net.Dial version with context arg func TCPDialCtx(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{} + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() conn, err := d.DialContext(ctx, network, addr) if err != nil { return nil, err From 54278c2aff2684aad9ee7ca0f7dac9886b9db527 Mon Sep 17 00:00:00 2001 From: Andrei Gavrilov Date: Wed, 10 Dec 2025 11:57:01 +0300 Subject: [PATCH 2/3] add default timeout for connect --- pkg/streamer/streamer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/streamer/streamer.go b/pkg/streamer/streamer.go index 533088a..7123b15 100644 --- a/pkg/streamer/streamer.go +++ b/pkg/streamer/streamer.go @@ -139,7 +139,7 @@ func NewReadResImpl(before, after []byte, matchedGroups map[string][]byte, match // TCPDialCtx net.Dial version with context arg func TCPDialCtx(ctx context.Context, network, addr string) (net.Conn, error) { d := net.Dialer{} - ctx, cancel := context.WithTimeout(ctx, time.Second*5) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() conn, err := d.DialContext(ctx, network, addr) if err != nil { From 65aa032df895a5b207b888677847bfbdcf9feef0 Mon Sep 17 00:00:00 2001 From: Andrei Gavrilov Date: Wed, 10 Dec 2025 12:42:15 +0300 Subject: [PATCH 3/3] remember connected address in colsole port discovery --- pkg/streamer/console/console.go | 84 +++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/pkg/streamer/console/console.go b/pkg/streamer/console/console.go index 3a7e3cd..c8c7fbb 100644 --- a/pkg/streamer/console/console.go +++ b/pkg/streamer/console/console.go @@ -62,19 +62,22 @@ const ( ) type Streamer struct { - consolePort string - speed int - currentSpeed int - redirectLimit int - redirectNo int - port int - forceAttach bool - host string - addresses []net.IP - credentials credentials.Credentials - portCredentials credentials.Credentials - logger *zap.Logger - conn net.Conn + consolePort string + speed int + currentSpeed int + redirectLimit int + redirectNo int + port int + forceAttach bool + host string + addresses []net.IP + credentials credentials.Credentials + portCredentials credentials.Credentials + logger *zap.Logger + conn net.Conn + // connectedAddress is the first address that Streamer actually managed to connect. + // Needed to not iterate again over each address during port discovery + connectedAddress string buffer chan []byte readerWg *errgroup.Group readerCancel context.CancelFunc @@ -131,6 +134,7 @@ func NewStreamer(host, consolePort string, credentials credentials.Credentials, portCredentials: portCredentials, logger: nil, conn: nil, + connectedAddress: "", buffer: nil, // buffer for catching console's messages readerWg: &errgroup.Group{}, readerCancel: nil, @@ -573,17 +577,43 @@ func (m *Streamer) stopBufferReader() error { return nil } -func (m *Streamer) setupConnection(ctx context.Context) error { - logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port)) - var endpoints []string - if m.addresses != nil { - endpoints = make([]string, 0, len(m.addresses)) +type endpoint struct { + address string + port int +} + +func (e *endpoint) HostPort() string { + return net.JoinHostPort(e.address, strconv.Itoa(e.port)) +} + +func (m *Streamer) getEndpoints() []endpoint { + if len(m.connectedAddress) != 0 { + return []endpoint{{ + address: m.connectedAddress, + port: m.port, + }} + } + if len(m.addresses) != 0 { + endpoints := make([]endpoint, 0, len(m.addresses)) for _, v := range m.addresses { - endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port))) + endpoints = append(endpoints, endpoint{ + address: v.String(), + port: m.port, + }) } - } else { - endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))} + return endpoints + } + return []endpoint{ + { + address: m.host, + port: m.port, + }, } +} + +func (m *Streamer) setupConnection(ctx context.Context) error { + logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port)) + endpoints := m.getEndpoints() if m.tunnel != nil || len(m.tunnelHost) > 0 { if m.tunnel == nil { logger.Debug("open tunnel", zap.String("tunnel", m.tunnelHost)) @@ -596,29 +626,31 @@ func (m *Streamer) setupConnection(ctx context.Context) error { } } for i, v := range endpoints { - logger.Debug("open tunnel connection", zap.String("host", v)) - conn, err := m.tunnel.StartForward(v) + logger.Debug("open tunnel connection", zap.String("host", v.HostPort())) + conn, err := m.tunnel.StartForward(v.HostPort()) if err == nil { + m.connectedAddress = v.address m.conn = conn break } if i == len(endpoints)-1 { return fmt.Errorf("tunnel forward error %w", err) } - logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err)) + logger.Debug("failed to connect endpoint, trying next", zap.String("remote endpoint", v.HostPort()), zap.Error(err)) } } else { logger.Debug("open connection") for i, v := range endpoints { - conn, err := streamer.TCPDialCtx(ctx, "tcp", v) + conn, err := streamer.TCPDialCtx(ctx, "tcp", v.HostPort()) if err == nil { + m.connectedAddress = v.address m.conn = conn break } if i == len(endpoints)-1 { return fmt.Errorf("failed to dial all given endpoints: %w", err) } - logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err)) + logger.Debug("failed to connect endpoint, trying next", zap.String("remote endpoint", v.HostPort()), zap.Error(err)) } }