diff --git a/pkg/streamer/ssh/ssh.go b/pkg/streamer/ssh/ssh.go index a5101c2..ec40609 100644 --- a/pkg/streamer/ssh/ssh.go +++ b/pkg/streamer/ssh/ssh.go @@ -430,37 +430,36 @@ func (m *Streamer) Cmd(ctx context.Context, cmd string) (gcmd.CmdRes, error) { sessionTemplate.session.Signal(ssh.SIGKILL) _ = sessionTemplate.session.Close() }) - err = sessionTemplate.session.Run(cmd) + err = sessionTemplate.session.Start(cmd) + if err != nil { + cancel() + return nil, fmt.Errorf("start cmd error: %w", err) + } + stdoutBytes, stderrBytes, copyErr := copySessionOutput(ctx, sessionTemplate.stdout, sessionTemplate.stderr) + waitErr := sessionTemplate.session.Wait() cancel() onSessionCloseErr := m.onSessionCloseCallbacks(sessionTemplate.session) if onSessionCloseErr != nil { - m.logger.Error("onSessionCloseCallbacks error %w", zap.Error(err)) + m.logger.Error("onSessionCloseCallbacks error %w", zap.Error(waitErr)) } + status := 0 isStatusGettingOk := false var execErr error - if err != nil { + if waitErr != nil { var errCode *ssh.ExitError - if errors.As(err, &errCode) { + if errors.As(waitErr, &errCode) { status = errCode.ExitStatus() isStatusGettingOk = true } else { - execErr = err + execErr = waitErr } } else { isStatusGettingOk = true } - - var stdoutBuffer, stderrBuffer bytes.Buffer - _, err = io.Copy(&stdoutBuffer, sessionTemplate.stdout) - if err != nil { - return nil, fmt.Errorf("failed to copy stdout: %w", err) - } - _, err = io.Copy(&stderrBuffer, sessionTemplate.stderr) - if err != nil { - return nil, fmt.Errorf("failed to copy stderr: %w", err) + if copyErr != nil { + return nil, copyErr } - stdoutBytes, stderrBytes := stdoutBuffer.Bytes(), stderrBuffer.Bytes() var res gcmd.CmdRes if isStatusGettingOk { res = gcmd.NewCmdResFull(stdoutBytes, stderrBytes, status, nil) @@ -477,6 +476,29 @@ func (m *Streamer) Cmd(ctx context.Context, cmd string) (gcmd.CmdRes, error) { return res, nil } +func copySessionOutput(ctx context.Context, stdout io.Reader, stderr io.Reader) ([]byte, []byte, error) { + var stdoutBuffer, stderrBuffer bytes.Buffer + eg, _ := errgroup.WithContext(ctx) + eg.Go(func() error { + _, err := io.Copy(&stdoutBuffer, stdout) + if err != nil { + return fmt.Errorf("failed to copy stdout: %w", err) + } + return nil + }) + eg.Go(func() error { + _, err := io.Copy(&stderrBuffer, stderr) + if err != nil { + return fmt.Errorf("failed to copy stderr: %w", err) + } + return nil + }) + if err := eg.Wait(); err != nil { + return nil, nil, err + } + return stdoutBuffer.Bytes(), stderrBuffer.Bytes(), nil +} + func (m *Streamer) GetConfig(ctx context.Context) (*ssh.ClientConfig, error) { creds := m.credentials if m.credentialsInterceptor != nil {