Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions pkg/streamer/console/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type Streamer struct {
port int
forceAttach bool
host string
addresses []net.IP
credentials credentials.Credentials
portCredentials credentials.Credentials
logger *zap.Logger
Expand Down Expand Up @@ -122,9 +123,10 @@ func NewStreamer(host, consolePort string, credentials credentials.Credentials,
currentSpeed: 0,
redirectLimit: defaultRedirectLimit,
redirectNo: 0,
port: defaultConserverPort, // дальше port будет меняться в случае редиректа
forceAttach: false,
host: host,
port: defaultConserverPort,
addresses: nil,
credentials: credentials,
portCredentials: portCredentials,
logger: nil,
Expand Down Expand Up @@ -168,6 +170,13 @@ func WithPort(port int) StreamerOption {
}
}

// WithAddresses makes streamer use given addresses for connection instead of host resolution
func WithAddresses(addresses []net.IP) StreamerOption {
return func(h *Streamer) {
h.addresses = addresses
}
}

func WithForceAttache() StreamerOption {
return func(h *Streamer) {
h.forceAttach = true
Expand Down Expand Up @@ -198,12 +207,6 @@ func WithSSHTunnelConn(tunnel sshtunnel.Tunnel) StreamerOption {
}
}

func WithConsolePort(port int) StreamerOption {
return func(h *Streamer) {
h.port = port
}
}

func WithSSHTunnel(tunnelHost string) StreamerOption {
return func(h *Streamer) {
h.tunnelHost = tunnelHost
Expand Down Expand Up @@ -572,10 +575,18 @@ func (m *Streamer) stopBufferReader() error {

func (m *Streamer) setupConnection(ctx context.Context) error {
logger := m.logger.With(zap.String("host", m.host), zap.Int("port", m.port))
remote := fmt.Sprintf("%s:%d", m.host, m.port)
var endpoints []string
if m.addresses != nil {
endpoints = make([]string, 0, len(m.addresses))
for _, v := range m.addresses {
endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port)))
}
} else {
endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))}
}
if m.tunnel != nil || len(m.tunnelHost) > 0 {
logger.Debug("open connection", zap.Any("tunnel", m.tunnel))
if m.tunnel == nil {
logger.Debug("open tunnel", zap.String("tunnel", m.tunnelHost))
m.tunnel = sshtunnel.NewSSHTunnel(m.tunnelHost, m.credentials)
}
if !m.tunnel.IsConnected() {
Expand All @@ -584,17 +595,30 @@ func (m *Streamer) setupConnection(ctx context.Context) error {
return err
}
}
conn, err := m.tunnel.StartForward(sshtunnel.TCP, remote)
if err != nil {
return fmt.Errorf("tunnel error %w", err)
for i, v := range endpoints {
logger.Debug("open tunnel connection", zap.String("host", v))
conn, err := m.tunnel.StartForward(v)
if err == nil {
m.conn = conn
break
}
if i == len(endpoints)-1 {
return fmt.Errorf("tunnel forward error %w", err)
}
logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err))
}
m.conn = conn
} else {
logger.Debug("open connection")
var err error
m.conn, err = streamer.TCPDialCtx(ctx, "tcp", remote)
if err != nil {
return err
for i, v := range endpoints {
conn, err := streamer.TCPDialCtx(ctx, "tcp", v)
if err == nil {
m.conn = conn
break
}
if i == len(endpoints)-1 {
return fmt.Errorf("failed to dial all given endpoints: %w", err)
}
logger.Debug("failed to connect endpoint, trying next", zap.Any("remote endpoint", v), zap.Error(err))
}
}

Expand Down
34 changes: 29 additions & 5 deletions pkg/streamer/rfc2217/rfc2217.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type Streamer struct {
logger *zap.Logger
host string
port int
addresses []net.IP
conn net.Conn
stdoutBuffer chan []byte
stdoutBufferExtra []byte
Expand Down Expand Up @@ -112,12 +113,27 @@ func (m *Streamer) SetCredentialsInterceptor(inter func(credentials.Credentials)
}

func (m *Streamer) Init(ctx context.Context) error {
m.logger.Debug("open connection", zap.String("host", m.host), zap.Int("port", m.port))
conn, err := streamer.TCPDialCtx(ctx, "tcp", net.JoinHostPort(m.host, strconv.Itoa(m.port)))
if err != nil {
return err
var endpoints []string
if len(m.addresses) != 0 {
endpoints = make([]string, 0, len(m.addresses))
for _, v := range m.addresses {
endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port)))
}
} else {
endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))}
}
for i, v := range endpoints {
m.logger.Debug("open connection", zap.String("endpoint", v))
conn, err := streamer.TCPDialCtx(ctx, "tcp", net.JoinHostPort(m.host, strconv.Itoa(m.port)))
if err == nil {
m.conn = conn
break
}
if i == len(endpoints)-1 {
return fmt.Errorf("failed to dial all given endpoints: %w", err)
}
m.logger.Debug("failed to dial, trying next", zap.String("endpoint", v), zap.Error(err))
}
m.conn = conn
// https://github.com/pyserial/pyserial/blob/master/serial/rfc2217.py#L430
mandadoryOptions := []telnetOption{
NewTelnetOption("we-BINARY", telnet.BBINARY, telnet.BWILL, telnet.BWONT, telnet.BDO, telnet.BDONT, INACTIVE, nil),
Expand Down Expand Up @@ -200,6 +216,7 @@ func NewStreamer(host string, port int, credentials credentials.Credentials, opt
logger: zap.NewNop(),
host: host,
port: port,
addresses: nil,
conn: nil,
stdoutBuffer: stdoutBuffer,
stdoutBufferExtra: nil,
Expand Down Expand Up @@ -264,6 +281,13 @@ func WithTrace(trace trace.CB) StreamerOption {
}
}

// WithAddresses makes streamer use given addresses for connection instead of host resolution
func WithAddresses(addresses []net.IP) StreamerOption {
return func(h *Streamer) {
h.addresses = addresses
}
}

func (m *Streamer) Close() {
if m.conn != nil {
_ = m.conn.Close()
Expand Down
123 changes: 51 additions & 72 deletions pkg/streamer/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,6 @@ import (
"github.com/annetutil/gnetcli/pkg/trace"
)

type Network string

const (
TCP Network = "tcp"
TCPv4 Network = "tcp4"
TCPv6 Network = "tcp6"
)

const (
defaultPort = 22
defaultReadTimeout = 20 * time.Second
Expand Down Expand Up @@ -100,32 +92,10 @@ type terminalParams struct {
echo bool
}

type Endpoint struct {
Host string
Port int
Network Network
}

func (endpoint Endpoint) String() string {
return fmt.Sprintf("{host: %s, port: %d, network: %s}", endpoint.Host, endpoint.Port, endpoint.Network)
}

func (endpoint *Endpoint) Addr() string {
return net.JoinHostPort(endpoint.Host, strconv.Itoa(endpoint.Port))
}

func NewEndpoint(host string, port int, network Network) Endpoint {
res := Endpoint{
Host: host,
Port: port,
Network: network,
}
return res
}

type Streamer struct {
endpoint Endpoint
additionalEndpoints []Endpoint
host string
port int
addresses []net.IP
credentials credentials.Credentials
logger *zap.Logger
conn sshClient
Expand Down Expand Up @@ -181,8 +151,9 @@ func (m *Streamer) SetTerminalEcho(e bool) {

func NewStreamer(host string, credentials credentials.Credentials, opts ...StreamerOption) *Streamer {
h := &Streamer{
endpoint: NewEndpoint(host, defaultPort, TCP),
additionalEndpoints: []Endpoint{},
host: host,
port: defaultPort,
addresses: nil,
credentials: credentials,
logger: nil,
conn: nil,
Expand Down Expand Up @@ -386,17 +357,17 @@ func WithLogger(log *zap.Logger) StreamerOption {
}
}

// WithPort sets port for default endpoint
func WithPort(port int) StreamerOption {
return func(h *Streamer) {
h.endpoint.Port = port
// WithAddresses makes streamer use given addresses for connection instead of host resolution
func WithAddresses(addresses []net.IP) StreamerOption {
return func(s *Streamer) {
s.addresses = addresses
}
}

// WithNetwork sets network for default endpoint
func WithNetwork(network Network) StreamerOption {
// WithPort sets port for connection
func WithPort(port int) StreamerOption {
return func(h *Streamer) {
h.endpoint.Network = network
h.port = port
}
}

Expand Down Expand Up @@ -426,14 +397,6 @@ func WithEnv(key, value string) StreamerOption {
}
}

// WithAdditionalEndpoints adds slice of endpoints that Streamer will sequentially try to connect to until success of dial,
// if original host dial fails
func WithAdditionalEndpoints(endpoints []Endpoint) StreamerOption {
return func(h *Streamer) {
h.additionalEndpoints = endpoints
}
}

func (m *Streamer) Close() {
m.forwardAgent = nil
if m.session != nil && m.session.session != nil {
Expand Down Expand Up @@ -632,7 +595,7 @@ func (m *Streamer) openConnect(ctx context.Context) (sshClient, error) {
// TODO: add support additionalEndpoints
conn, err = OpenControl(m.controlFile)
} else {
conn, err = DialCtx(ctx, m.endpoint, m.additionalEndpoints, conf, m.logger)
conn, err = DialCtx(ctx, m.host, m.port, m.addresses, conf, m.logger)
}

return conn, err
Expand All @@ -647,24 +610,32 @@ func (m *Streamer) dialTunnel(ctx context.Context, conf *ssh.ClientConfig) (*ssh
}
var tunConn net.Conn
var err error
var connectedEndpoint Endpoint
endpoints := append([]Endpoint{m.endpoint}, m.additionalEndpoints...)
var connectedEndpoint string
var endpoints []string
if len(m.addresses) != 0 {
endpoints = make([]string, 0, len(m.addresses))
for _, v := range m.addresses {
endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(m.port)))
}
} else {
endpoints = []string{net.JoinHostPort(m.host, strconv.Itoa(m.port))}
}
for _, endpoint := range endpoints {
connectedEndpoint = endpoint
tunConn, err = m.tunnel.StartForward(endpoint.Network, endpoint.Addr())
tunConn, err = m.tunnel.StartForward(endpoint)
if err == nil {
connectedEndpoint = endpoint
break
}
m.logger.Debug("failed to open tunnel for endpoint", zap.String("address", endpoint.String()), zap.Error(err))
m.logger.Debug("failed to open tunnel for endpoint", zap.String("address", endpoint), zap.Error(err))
}
if err != nil {
m.tunnel.Close()
return nil, fmt.Errorf("failed to open tunnel for any of given hosts: %v, last error: %w", m.endpoint, err)
return nil, fmt.Errorf("failed to open tunnel for any of given hosts: %v, last error: %w", endpoints, err)
}
m.logger.Debug("dial tunnel", zap.String("address", connectedEndpoint.String()))
res, err := DialConnCtx(ctx, tunConn, connectedEndpoint.Addr(), conf)
m.logger.Debug("dial tunnel", zap.String("address", connectedEndpoint))
res, err := DialConnCtx(ctx, tunConn, connectedEndpoint, conf)
if err != nil {
return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint.String(), err)
return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint, err)
}
return res, nil
}
Expand Down Expand Up @@ -793,7 +764,7 @@ func (m *Streamer) Init(ctx context.Context) error {
return fmt.Errorf("already inited")
}
m.inited = true
m.logger.Debug("open connection", zap.Stringer("endpoint", m.endpoint), zap.Stringers("additional endpoints", m.additionalEndpoints))
m.logger.Debug("open connection", zap.String("host", m.host), zap.Int("port", m.port), zap.Stringers("addresses", m.addresses))

conn, err := m.openConnect(ctx)
if err != nil {
Expand Down Expand Up @@ -1134,28 +1105,36 @@ func (m *Streamer) uploadSftp(filePaths map[string]streamer.File, useSudo bool)
}

// DialCtx ssh.Dial version with context arg
func DialCtx(ctx context.Context, endpoint Endpoint, additionalEndpoints []Endpoint, config *ssh.ClientConfig, logger *zap.Logger) (*ssh.Client, error) {
func DialCtx(ctx context.Context, host string, port int, addresses []net.IP, config *ssh.ClientConfig, logger *zap.Logger) (*ssh.Client, error) {
var err error
var conn net.Conn
var connectedEndpoint Endpoint
endpoints := append([]Endpoint{endpoint}, additionalEndpoints...)
var connectedEndpoint string
var endpoints []string
if len(addresses) != 0 {
endpoints = make([]string, 0, len(addresses))
for _, v := range addresses {
endpoints = append(endpoints, net.JoinHostPort(v.String(), strconv.Itoa(port)))
}
} else {
endpoints = []string{net.JoinHostPort(host, strconv.Itoa(port))}
}
for _, endpoint := range endpoints {
connectedEndpoint = endpoint
logger.Debug("tcp dial", zap.String("address", connectedEndpoint.String()))
conn, err = streamer.TCPDialCtx(ctx, string(endpoint.Network), endpoint.Addr())
logger.Debug("tcp dial", zap.String("address", endpoint))
conn, err = streamer.TCPDialCtx(ctx, "tcp", endpoint)
if err == nil {
connectedEndpoint = endpoint
break
}
// always continue attempts to connect in case of dial failure
logger.Debug("dial failed for endpoint", zap.String("endpoint", endpoint.String()), zap.Error(err))
logger.Debug("dial failed for endpoint", zap.String("endpoint", endpoint), zap.Error(err))
}
if err != nil {
return nil, fmt.Errorf("failed to dial any of given endpoints: %v, last error: %w", endpoint, err)
return nil, fmt.Errorf("failed to dial any of given endpoints: %v, last error: %w", endpoints, err)
}
logger.Debug("tcp ssh", zap.String("address", connectedEndpoint.String()))
res, err := DialConnCtx(ctx, conn, connectedEndpoint.Addr(), config)
logger.Debug("tcp ssh", zap.String("address", connectedEndpoint))
res, err := DialConnCtx(ctx, conn, connectedEndpoint, config)
if err != nil {
return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint.String(), err)
return nil, fmt.Errorf("failed to connect to host %s: %w", connectedEndpoint, err)
}
return res, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/streamer/ssh/ssh_control_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ func resolveHomeDir(path string) string {
return path
}

func dialControlMasterConf(_ context.Context, controlFile string, endpoint Endpoint, conf *ssh.ClientConfig, logger *zap.Logger) (*ControlConn, error) {
params := tssh.NewSshParam(endpoint.Host, strconv.Itoa(endpoint.Port), conf.User, nil)
func dialControlMasterConf(_ context.Context, controlFile string, host string, port int, conf *ssh.ClientConfig, logger *zap.Logger) (*ControlConn, error) {
params := tssh.NewSshParam(host, strconv.Itoa(port), conf.User, nil)
expandedPath, err := tssh.ExpandTokens(controlFile, params, "%CdhijkLlnpru")
if err != nil {
return nil, err
Expand Down
Loading
Loading