diff --git a/pkg/streamer/console/console.go b/pkg/streamer/console/console.go index d0bea74..3a7e3cd 100644 --- a/pkg/streamer/console/console.go +++ b/pkg/streamer/console/console.go @@ -70,6 +70,7 @@ type Streamer struct { port int forceAttach bool host string + addresses []net.IP credentials credentials.Credentials portCredentials credentials.Credentials logger *zap.Logger @@ -122,9 +123,10 @@ func NewStreamer(host, consolePort string, credentials credentials.Credentials, currentSpeed: 0, redirectLimit: defaultRedirectLimit, redirectNo: 0, - port: defaultConserverPort, // дальше port будет меняться в случае редиректа forceAttach: false, host: host, + port: defaultConserverPort, + addresses: nil, credentials: credentials, portCredentials: portCredentials, logger: nil, @@ -168,6 +170,13 @@ func WithPort(port int) StreamerOption { } } +// WithAddresses makes streamer use given addresses for connection instead of host resolution +func WithAddresses(addresses []net.IP) StreamerOption { + return func(h *Streamer) { + h.addresses = addresses + } +} + func WithForceAttache() StreamerOption { return func(h *Streamer) { h.forceAttach = true @@ -198,12 +207,6 @@ func WithSSHTunnelConn(tunnel sshtunnel.Tunnel) StreamerOption { } } -func WithConsolePort(port int) StreamerOption { - return func(h *Streamer) { - h.port = port - } -} - func WithSSHTunnel(tunnelHost string) StreamerOption { return func(h *Streamer) { h.tunnelHost = tunnelHost @@ -572,10 +575,18 @@ func (m *Streamer) stopBufferReader() error { func (m *Streamer) setupConnection(ctx context.Context) error { logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port)) - remote := fmt.Sprintf("%s:%d", m.host, m.port) + var endpoints []string + if m.addresses != nil { + endpoints = make([]string, 0, len(m.addresses)) + for _, v := range m.addresses { + endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port))) + } + } else { + endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))} + } if m.tunnel != nil || len(m.tunnelHost) > 0 { - logger.Debug("open connection", zap.Any("tunnel", m.tunnel)) if m.tunnel == nil { + logger.Debug("open tunnel", zap.String("tunnel", m.tunnelHost)) m.tunnel = sshtunnel.NewSSHTunnel(m.tunnelHost, m.credentials) } if !m.tunnel.IsConnected() { @@ -584,17 +595,30 @@ func (m *Streamer) setupConnection(ctx context.Context) error { return err } } - conn, err := m.tunnel.StartForward(sshtunnel.TCP, remote) - if err != nil { - return fmt.Errorf("tunnel error %w", err) + for i, v := range endpoints { + logger.Debug("open tunnel connection", zap.String("host", v)) + conn, err := m.tunnel.StartForward(v) + if err == nil { + m.conn = conn + break + } + if i == len(endpoints)-1 { + return fmt.Errorf("tunnel forward error %w", err) + } + logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err)) } - m.conn = conn } else { logger.Debug("open connection") - var err error - m.conn, err = streamer.TCPDialCtx(ctx, "tcp", remote) - if err != nil { - return err + for i, v := range endpoints { + conn, err := streamer.TCPDialCtx(ctx, "tcp", v) + if err == nil { + m.conn = conn + break + } + if i == len(endpoints)-1 { + return fmt.Errorf("failed to dial all given endpoints: %w", err) + } + logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err)) } } diff --git a/pkg/streamer/rfc2217/rfc2217.go b/pkg/streamer/rfc2217/rfc2217.go index 289563b..42694c4 100644 --- a/pkg/streamer/rfc2217/rfc2217.go +++ b/pkg/streamer/rfc2217/rfc2217.go @@ -69,6 +69,7 @@ type Streamer struct { logger *zap.Logger host string port int + addresses []net.IP conn net.Conn stdoutBuffer chan []byte stdoutBufferExtra []byte @@ -112,12 +113,27 @@ func (m *Streamer) SetCredentialsInterceptor(inter func(credentials.Credentials) } func (m *Streamer) Init(ctx context.Context) error { - m.logger.Debug("open connection", zap.String("host", m.host), zap.Int("port", m.port)) - conn, err := streamer.TCPDialCtx(ctx, "tcp", net.JoinHostPort(m.host, strconv.Itoa(m.port))) - if err != nil { - return err + var endpoints []string + if len(m.addresses) != 0 { + endpoints = make([]string, 0, len(m.addresses)) + for _, v := range m.addresses { + endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port))) + } + } else { + endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))} + } + for i, v := range endpoints { + m.logger.Debug("open connection", zap.String("endpoint", v)) + conn, err := streamer.TCPDialCtx(ctx, "tcp", net.JoinHostPort(m.host, strconv.Itoa(m.port))) + if err == nil { + m.conn = conn + break + } + if i == len(endpoints)-1 { + return fmt.Errorf("failed to dial all given endpoints: %w", err) + } + m.logger.Debug("failed to dial, trying next", zap.String("endpoint", v), zap.Error(err)) } - m.conn = conn // https://github.com/pyserial/pyserial/blob/master/serial/rfc2217.py#L430 mandadoryOptions := []telnetOption{ NewTelnetOption("we-BINARY", telnet.BBINARY, telnet.BWILL, telnet.BWONT, telnet.BDO, telnet.BDONT, INACTIVE, nil), @@ -200,6 +216,7 @@ func NewStreamer(host string, port int, credentials credentials.Credentials, opt logger: zap.NewNop(), host: host, port: port, + addresses: nil, conn: nil, stdoutBuffer: stdoutBuffer, stdoutBufferExtra: nil, @@ -264,6 +281,13 @@ func WithTrace(trace trace.CB) StreamerOption { } } +// WithAddresses makes streamer use given addresses for connection instead of host resolution +func WithAddresses(addresses []net.IP) StreamerOption { + return func(h *Streamer) { + h.addresses = addresses + } +} + func (m *Streamer) Close() { if m.conn != nil { _ = m.conn.Close() diff --git a/pkg/streamer/ssh/ssh.go b/pkg/streamer/ssh/ssh.go index 6ceaa8c..a5101c2 100644 --- a/pkg/streamer/ssh/ssh.go +++ b/pkg/streamer/ssh/ssh.go @@ -35,14 +35,6 @@ import ( "github.com/annetutil/gnetcli/pkg/trace" ) -type Network string - -const ( - TCP Network = "tcp" - TCPv4 Network = "tcp4" - TCPv6 Network = "tcp6" -) - const ( defaultPort = 22 defaultReadTimeout = 20 * time.Second @@ -100,32 +92,10 @@ type terminalParams struct { echo bool } -type Endpoint struct { - Host string - Port int - Network Network -} - -func (endpoint Endpoint) String() string { - return fmt.Sprintf("{host: %s, port: %d, network: %s}", endpoint.Host, endpoint.Port, endpoint.Network) -} - -func (endpoint *Endpoint) Addr() string { - return net.JoinHostPort(endpoint.Host, strconv.Itoa(endpoint.Port)) -} - -func NewEndpoint(host string, port int, network Network) Endpoint { - res := Endpoint{ - Host: host, - Port: port, - Network: network, - } - return res -} - type Streamer struct { - endpoint Endpoint - additionalEndpoints []Endpoint + host string + port int + addresses []net.IP credentials credentials.Credentials logger *zap.Logger conn sshClient @@ -181,8 +151,9 @@ func (m *Streamer) SetTerminalEcho(e bool) { func NewStreamer(host string, credentials credentials.Credentials, opts ...StreamerOption) *Streamer { h := &Streamer{ - endpoint: NewEndpoint(host, defaultPort, TCP), - additionalEndpoints: []Endpoint{}, + host: host, + port: defaultPort, + addresses: nil, credentials: credentials, logger: nil, conn: nil, @@ -386,17 +357,17 @@ func WithLogger(log *zap.Logger) StreamerOption { } } -// WithPort sets port for default endpoint -func WithPort(port int) StreamerOption { - return func(h *Streamer) { - h.endpoint.Port = port +// WithAddresses makes streamer use given addresses for connection instead of host resolution +func WithAddresses(addresses []net.IP) StreamerOption { + return func(s *Streamer) { + s.addresses = addresses } } -// WithNetwork sets network for default endpoint -func WithNetwork(network Network) StreamerOption { +// WithPort sets port for connection +func WithPort(port int) StreamerOption { return func(h *Streamer) { - h.endpoint.Network = network + h.port = port } } @@ -426,14 +397,6 @@ func WithEnv(key, value string) StreamerOption { } } -// WithAdditionalEndpoints adds slice of endpoints that Streamer will sequentially try to connect to until success of dial, -// if original host dial fails -func WithAdditionalEndpoints(endpoints []Endpoint) StreamerOption { - return func(h *Streamer) { - h.additionalEndpoints = endpoints - } -} - func (m *Streamer) Close() { m.forwardAgent = nil if m.session != nil && m.session.session != nil { @@ -632,7 +595,7 @@ func (m *Streamer) openConnect(ctx context.Context) (sshClient, error) { // TODO: add support additionalEndpoints conn, err = OpenControl(m.controlFile) } else { - conn, err = DialCtx(ctx, m.endpoint, m.additionalEndpoints, conf, m.logger) + conn, err = DialCtx(ctx, m.host, m.port, m.addresses, conf, m.logger) } return conn, err @@ -647,24 +610,32 @@ func (m *Streamer) dialTunnel(ctx context.Context, conf *ssh.ClientConfig) (*ssh } var tunConn net.Conn var err error - var connectedEndpoint Endpoint - endpoints := append([]Endpoint{m.endpoint}, m.additionalEndpoints...) + var connectedEndpoint string + var endpoints []string + if len(m.addresses) != 0 { + endpoints = make([]string, 0, len(m.addresses)) + for _, v := range m.addresses { + endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port))) + } + } else { + endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))} + } for _, endpoint := range endpoints { - connectedEndpoint = endpoint - tunConn, err = m.tunnel.StartForward(endpoint.Network, endpoint.Addr()) + tunConn, err = m.tunnel.StartForward(endpoint) if err == nil { + connectedEndpoint = endpoint break } - m.logger.Debug("failed to open tunnel for endpoint", zap.String("address", endpoint.String()), zap.Error(err)) + m.logger.Debug("failed to open tunnel for endpoint", zap.String("address", endpoint), zap.Error(err)) } if err != nil { m.tunnel.Close() - return nil, fmt.Errorf("failed to open tunnel for any of given hosts: %v, last error: %w", m.endpoint, err) + return nil, fmt.Errorf("failed to open tunnel for any of given hosts: %v, last error: %w", endpoints, err) } - m.logger.Debug("dial tunnel", zap.String("address", connectedEndpoint.String())) - res, err := DialConnCtx(ctx, tunConn, connectedEndpoint.Addr(), conf) + m.logger.Debug("dial tunnel", zap.String("address", connectedEndpoint)) + res, err := DialConnCtx(ctx, tunConn, connectedEndpoint, conf) if err != nil { - return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint.String(), err) + return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint, err) } return res, nil } @@ -793,7 +764,7 @@ func (m *Streamer) Init(ctx context.Context) error { return fmt.Errorf("already inited") } m.inited = true - m.logger.Debug("open connection", zap.Stringer("endpoint", m.endpoint), zap.Stringers("additional endpoints", m.additionalEndpoints)) + m.logger.Debug("open connection", zap.String("host", m.host), zap.Int("port", m.port), zap.Stringers("addresses", m.addresses)) conn, err := m.openConnect(ctx) if err != nil { @@ -1134,28 +1105,36 @@ func (m *Streamer) uploadSftp(filePaths map[string]streamer.File, useSudo bool) } // DialCtx ssh.Dial version with context arg -func DialCtx(ctx context.Context, endpoint Endpoint, additionalEndpoints []Endpoint, config *ssh.ClientConfig, logger *zap.Logger) (*ssh.Client, error) { +func DialCtx(ctx context.Context, host string, port int, addresses []net.IP, config *ssh.ClientConfig, logger *zap.Logger) (*ssh.Client, error) { var err error var conn net.Conn - var connectedEndpoint Endpoint - endpoints := append([]Endpoint{endpoint}, additionalEndpoints...) + var connectedEndpoint string + var endpoints []string + if len(addresses) != 0 { + endpoints = make([]string, 0, len(addresses)) + for _, v := range addresses { + endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(port))) + } + } else { + endpoints = []string{net.JoinHostPort(host, strconv.Itoa(port))} + } for _, endpoint := range endpoints { - connectedEndpoint = endpoint - logger.Debug("tcp dial", zap.String("address", connectedEndpoint.String())) - conn, err = streamer.TCPDialCtx(ctx, string(endpoint.Network), endpoint.Addr()) + logger.Debug("tcp dial", zap.String("address", endpoint)) + conn, err = streamer.TCPDialCtx(ctx, "tcp", endpoint) if err == nil { + connectedEndpoint = endpoint break } // always continue attempts to connect in case of dial failure - logger.Debug("dial failed for endpoint", zap.String("endpoint", endpoint.String()), zap.Error(err)) + logger.Debug("dial failed for endpoint", zap.String("endpoint", endpoint), zap.Error(err)) } if err != nil { - return nil, fmt.Errorf("failed to dial any of given endpoints: %v, last error: %w", endpoint, err) + return nil, fmt.Errorf("failed to dial any of given endpoints: %v, last error: %w", endpoints, err) } - logger.Debug("tcp ssh", zap.String("address", connectedEndpoint.String())) - res, err := DialConnCtx(ctx, conn, connectedEndpoint.Addr(), config) + logger.Debug("tcp ssh", zap.String("address", connectedEndpoint)) + res, err := DialConnCtx(ctx, conn, connectedEndpoint, config) if err != nil { - return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint.String(), err) + return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint, err) } return res, err } diff --git a/pkg/streamer/ssh/ssh_control_file.go b/pkg/streamer/ssh/ssh_control_file.go index d87e7f9..b404e59 100644 --- a/pkg/streamer/ssh/ssh_control_file.go +++ b/pkg/streamer/ssh/ssh_control_file.go @@ -24,8 +24,8 @@ func resolveHomeDir(path string) string { return path } -func dialControlMasterConf(_ context.Context, controlFile string, endpoint Endpoint, conf *ssh.ClientConfig, logger *zap.Logger) (*ControlConn, error) { - params := tssh.NewSshParam(endpoint.Host, strconv.Itoa(endpoint.Port), conf.User, nil) +func dialControlMasterConf(_ context.Context, controlFile string, host string, port int, conf *ssh.ClientConfig, logger *zap.Logger) (*ControlConn, error) { + params := tssh.NewSshParam(host, strconv.Itoa(port), conf.User, nil) expandedPath, err := tssh.ExpandTokens(controlFile, params, "%CdhijkLlnpru") if err != nil { return nil, err diff --git a/pkg/streamer/ssh/ssh_test.go b/pkg/streamer/ssh/ssh_test.go index 6e66702..96e802b 100644 --- a/pkg/streamer/ssh/ssh_test.go +++ b/pkg/streamer/ssh/ssh_test.go @@ -14,33 +14,3 @@ func TestSSHInterface(t *testing.T) { _, ok := interface{}(&val).(streamer.Connector) assert.True(t, ok, "not a Connector interface") } - -func TestEndpoint_Addr(t *testing.T) { - tests := []struct { - name string - endpoint Endpoint - expected string - }{ - { - name: "default", - endpoint: Endpoint{Host: "localhost", Port: 22}, - expected: "localhost:22", - }, - { - name: "custom port", - endpoint: Endpoint{Host: "example.com", Port: 2222}, - expected: "example.com:2222", - }, - { - name: "IPv6", - endpoint: Endpoint{Host: "2001:db8::1", Port: 22}, - expected: "[2001:db8::1]:22", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expected, tt.endpoint.Addr()) - }) - } -} diff --git a/pkg/streamer/ssh/ssh_tunnel.go b/pkg/streamer/ssh/ssh_tunnel.go index 76a4388..abcccfc 100644 --- a/pkg/streamer/ssh/ssh_tunnel.go +++ b/pkg/streamer/ssh/ssh_tunnel.go @@ -22,11 +22,13 @@ type Tunnel interface { Close() IsConnected() bool CreateConnect(context.Context) error - StartForward(network Network, addr string) (net.Conn, error) + StartForward(addr string) (net.Conn, error) } type SSHTunnel struct { - Server Endpoint + Host string + Port int + Addresses []net.IP Config *ssh.ClientConfig svrConn *ssh.Client stdioForward *ControlConn @@ -39,7 +41,8 @@ type SSHTunnel struct { func NewSSHTunnel(host string, credentials credentials.Credentials, opts ...SSHTunnelOption) *SSHTunnel { h := &SSHTunnel{ - Server: NewEndpoint(host, defaultPort, TCP), + Host: host, + Port: defaultPort, Config: nil, svrConn: nil, isOpen: false, @@ -68,15 +71,16 @@ func SSHTunnelWithControlFIle(path string) SSHTunnelOption { } } -func SSHTunnelWithNetwork(network Network) SSHTunnelOption { +func SSHTunnelWithPort(port int) SSHTunnelOption { return func(h *SSHTunnel) { - h.Server.Network = network + h.Port = port } } -func SSHTunnelWitPort(port int) SSHTunnelOption { +// SSHTunnelWithAddresses makes ssh tunnel use given addresses for connection instead of host resolution +func SSHTunnelWithAddresses(addresses []net.IP) SSHTunnelOption { return func(h *SSHTunnel) { - h.Server.Port = port + h.Addresses = addresses } } @@ -89,7 +93,7 @@ func (m *SSHTunnel) CreateConnect(ctx context.Context) error { if len(m.controlFile) > 0 { strOpts = append(strOpts, WithSSHControlFIle(m.controlFile)) } - connector := NewStreamer(m.Server.Host, m.credentials, strOpts...) + connector := NewStreamer(m.Host, m.credentials, strOpts...) conf, err := connector.GetConfig(ctx) if err != nil { m.logger.Error(err.Error()) @@ -100,14 +104,14 @@ func (m *SSHTunnel) CreateConnect(ctx context.Context) error { var conn *ssh.Client if len(m.controlFile) != 0 { - mConn, err := dialControlMasterConf(ctx, m.controlFile, m.Server, conf, m.logger) + mConn, err := dialControlMasterConf(ctx, m.controlFile, m.Host, m.Port, conf, m.logger) if err != nil { return err } m.stdioForward = mConn conn = nil } else { - conn, err = DialCtx(ctx, m.Server, nil, m.Config, m.logger) + conn, err = DialCtx(ctx, m.Host, m.Port, m.Addresses, m.Config, m.logger) } if err != nil { m.logger.Debug("unable to connect to tunnel", zap.Error(err)) @@ -116,13 +120,13 @@ func (m *SSHTunnel) CreateConnect(ctx context.Context) error { } return err } - m.logger.Debug("connected to tunnel", zap.String("server", m.Server.String())) + m.logger.Debug("connected to tunnel", zap.String("host", m.Host), zap.Int("port", m.Port)) m.svrConn = conn m.isOpen = true return nil } -func (m *SSHTunnel) StartForward(network Network, remoteAddr string) (net.Conn, error) { +func (m *SSHTunnel) StartForward(remoteAddr string) (net.Conn, error) { if m.stdioForward != nil { host, port, err := net.SplitHostPort(remoteAddr) if err != nil { @@ -145,7 +149,7 @@ func (m *SSHTunnel) StartForward(network Network, remoteAddr string) (net.Conn, if err != nil { return nil, err } - remoteConn, err := m.svrConn.Dial(string(network), remoteAddr) + remoteConn, err := m.svrConn.Dial("tcp", remoteAddr) if err != nil { return nil, err } diff --git a/pkg/streamer/telnet/telnet.go b/pkg/streamer/telnet/telnet.go index eeccf7a..1d31e5b 100644 --- a/pkg/streamer/telnet/telnet.go +++ b/pkg/streamer/telnet/telnet.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "strconv" "time" "go.uber.org/zap" @@ -59,6 +60,7 @@ type Streamer struct { credentials credentials.Credentials logger *zap.Logger host string + addresses []net.IP conn net.Conn stdoutBuffer chan []byte stdoutBufferExtra []byte @@ -96,11 +98,29 @@ func (m *Streamer) SetCredentialsInterceptor(inter func(credentials.Credentials) func (m *Streamer) Init(ctx context.Context) error { m.logger.Debug("open connection", zap.String("host", m.host)) - conn, err := streamer.TCPDialCtx(ctx, "tcp", fmt.Sprintf("%s:%d", m.host, defaultPort)) - if err != nil { - return err + var endpoints []string + if len(m.addresses) != 0 { + endpoints = make([]string, 0, len(m.addresses)) + for _, v := range m.addresses { + endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(defaultPort))) + } + } else { + endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(defaultPort))} } - m.conn = conn + + for i, v := range endpoints { + conn, err := streamer.TCPDialCtx(ctx, "tcp", v) + if err == nil { + m.conn = conn + break + } + if i == len(endpoints)-1 { + return fmt.Errorf("failed to dial all given endpoints %v: %w", endpoints, err) + } + + m.logger.Debug("failed to connect endpoint, trying next", zap.String("endpoint", v), zap.Error(err)) + } + eg, _ := errgroup.WithContext(ctx) eg.Go(func() error { return m.stdoutReader(m.conn) }) return nil @@ -116,6 +136,7 @@ func NewStreamer(host string, credentials credentials.Credentials, opts ...Strea credentials: credentials, logger: zap.NewNop(), host: host, + addresses: nil, conn: nil, stdoutBuffer: stdoutBuffer, stdoutBufferExtra: nil, @@ -179,6 +200,13 @@ func WithTrace(trace trace.CB) StreamerOption { } } +// WithAddresses makes streamer use given addresses for connection instead of host resolution +func WithAddresses(addresses []net.IP) StreamerOption { + return func(h *Streamer) { + h.addresses = addresses + } +} + func (m *Streamer) Close() { if m.conn != nil { _ = m.conn.Close()