Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 58 additions & 26 deletions pkg/streamer/console/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,22 @@ const (
)

type Streamer struct {
consolePort string
speed int
currentSpeed int
redirectLimit int
redirectNo int
port int
forceAttach bool
host string
addresses []net.IP
credentials credentials.Credentials
portCredentials credentials.Credentials
logger *zap.Logger
conn net.Conn
consolePort string
speed int
currentSpeed int
redirectLimit int
redirectNo int
port int
forceAttach bool
host string
addresses []net.IP
credentials credentials.Credentials
portCredentials credentials.Credentials
logger *zap.Logger
conn net.Conn
// connectedAddress is the first address that Streamer actually managed to connect.
// Needed to not iterate again over each address during port discovery
connectedAddress string
buffer chan []byte
readerWg *errgroup.Group
readerCancel context.CancelFunc
Expand Down Expand Up @@ -131,6 +134,7 @@ func NewStreamer(host, consolePort string, credentials credentials.Credentials,
portCredentials: portCredentials,
logger: nil,
conn: nil,
connectedAddress: "",
buffer: nil, // buffer for catching console's messages
readerWg: &errgroup.Group{},
readerCancel: nil,
Expand Down Expand Up @@ -573,17 +577,43 @@ func (m *Streamer) stopBufferReader() error {
return nil
}

func (m *Streamer) setupConnection(ctx context.Context) error {
logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port))
var endpoints []string
if m.addresses != nil {
endpoints = make([]string, 0, len(m.addresses))
type endpoint struct {
address string
port int
}

func (e *endpoint) HostPort() string {
return net.JoinHostPort(e.address, strconv.Itoa(e.port))
}

func (m *Streamer) getEndpoints() []endpoint {
if len(m.connectedAddress) != 0 {
return []endpoint{{
address: m.connectedAddress,
port: m.port,
}}
}
if len(m.addresses) != 0 {
endpoints := make([]endpoint, 0, len(m.addresses))
for _, v := range m.addresses {
endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port)))
endpoints = append(endpoints, endpoint{
address: v.String(),
port: m.port,
})
}
} else {
endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))}
return endpoints
}
return []endpoint{
{
address: m.host,
port: m.port,
},
}
}

func (m *Streamer) setupConnection(ctx context.Context) error {
logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port))
endpoints := m.getEndpoints()
if m.tunnel != nil || len(m.tunnelHost) > 0 {
if m.tunnel == nil {
logger.Debug("open tunnel", zap.String("tunnel", m.tunnelHost))
Expand All @@ -596,29 +626,31 @@ func (m *Streamer) setupConnection(ctx context.Context) error {
}
}
for i, v := range endpoints {
logger.Debug("open tunnel connection", zap.String("host", v))
conn, err := m.tunnel.StartForward(v)
logger.Debug("open tunnel connection", zap.String("host", v.HostPort()))
conn, err := m.tunnel.StartForward(v.HostPort())
if err == nil {
m.connectedAddress = v.address
m.conn = conn
break
}
if i == len(endpoints)-1 {
return fmt.Errorf("tunnel forward error %w", err)
}
logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err))
logger.Debug("failed to connect endpoint, trying next", zap.String("remote endpoint", v.HostPort()), zap.Error(err))
}
} else {
logger.Debug("open connection")
for i, v := range endpoints {
conn, err := streamer.TCPDialCtx(ctx, "tcp", v)
conn, err := streamer.TCPDialCtx(ctx, "tcp", v.HostPort())
if err == nil {
m.connectedAddress = v.address
m.conn = conn
break
}
if i == len(endpoints)-1 {
return fmt.Errorf("failed to dial all given endpoints: %w", err)
}
logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err))
logger.Debug("failed to connect endpoint, trying next", zap.String("remote endpoint", v.HostPort()), zap.Error(err))
}
}

Expand Down
Loading