diff --git a/pkg/streamer/console/console.go b/pkg/streamer/console/console.go index 3a7e3cd..c8c7fbb 100644 --- a/pkg/streamer/console/console.go +++ b/pkg/streamer/console/console.go @@ -62,19 +62,22 @@ const ( ) type Streamer struct { - consolePort string - speed int - currentSpeed int - redirectLimit int - redirectNo int - port int - forceAttach bool - host string - addresses []net.IP - credentials credentials.Credentials - portCredentials credentials.Credentials - logger *zap.Logger - conn net.Conn + consolePort string + speed int + currentSpeed int + redirectLimit int + redirectNo int + port int + forceAttach bool + host string + addresses []net.IP + credentials credentials.Credentials + portCredentials credentials.Credentials + logger *zap.Logger + conn net.Conn + // connectedAddress is the first address that Streamer actually managed to connect. + // Needed to not iterate again over each address during port discovery + connectedAddress string buffer chan []byte readerWg *errgroup.Group readerCancel context.CancelFunc @@ -131,6 +134,7 @@ func NewStreamer(host, consolePort string, credentials credentials.Credentials, portCredentials: portCredentials, logger: nil, conn: nil, + connectedAddress: "", buffer: nil, // buffer for catching console's messages readerWg: &errgroup.Group{}, readerCancel: nil, @@ -573,17 +577,43 @@ func (m *Streamer) stopBufferReader() error { return nil } -func (m *Streamer) setupConnection(ctx context.Context) error { - logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port)) - var endpoints []string - if m.addresses != nil { - endpoints = make([]string, 0, len(m.addresses)) +type endpoint struct { + address string + port int +} + +func (e *endpoint) HostPort() string { + return net.JoinHostPort(e.address, strconv.Itoa(e.port)) +} + +func (m *Streamer) getEndpoints() []endpoint { + if len(m.connectedAddress) != 0 { + return []endpoint{{ + address: m.connectedAddress, + port: m.port, + }} + } + if len(m.addresses) != 0 { + endpoints := make([]endpoint, 0, len(m.addresses)) for _, v := range m.addresses { - endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port))) + endpoints = append(endpoints, endpoint{ + address: v.String(), + port: m.port, + }) } - } else { - endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))} + return endpoints + } + return []endpoint{ + { + address: m.host, + port: m.port, + }, } +} + +func (m *Streamer) setupConnection(ctx context.Context) error { + logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port)) + endpoints := m.getEndpoints() if m.tunnel != nil || len(m.tunnelHost) > 0 { if m.tunnel == nil { logger.Debug("open tunnel", zap.String("tunnel", m.tunnelHost)) @@ -596,29 +626,31 @@ func (m *Streamer) setupConnection(ctx context.Context) error { } } for i, v := range endpoints { - logger.Debug("open tunnel connection", zap.String("host", v)) - conn, err := m.tunnel.StartForward(v) + logger.Debug("open tunnel connection", zap.String("host", v.HostPort())) + conn, err := m.tunnel.StartForward(v.HostPort()) if err == nil { + m.connectedAddress = v.address m.conn = conn break } if i == len(endpoints)-1 { return fmt.Errorf("tunnel forward error %w", err) } - logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err)) + logger.Debug("failed to connect endpoint, trying next", zap.String("remote endpoint", v.HostPort()), zap.Error(err)) } } else { logger.Debug("open connection") for i, v := range endpoints { - conn, err := streamer.TCPDialCtx(ctx, "tcp", v) + conn, err := streamer.TCPDialCtx(ctx, "tcp", v.HostPort()) if err == nil { + m.connectedAddress = v.address m.conn = conn break } if i == len(endpoints)-1 { return fmt.Errorf("failed to dial all given endpoints: %w", err) } - logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err)) + logger.Debug("failed to connect endpoint, trying next", zap.String("remote endpoint", v.HostPort()), zap.Error(err)) } }