diff --git a/internal/workflow/signals.go b/internal/workflow/signals.go index 4e34aff2..186a1411 100644 --- a/internal/workflow/signals.go +++ b/internal/workflow/signals.go @@ -5,13 +5,16 @@ package workflow import ( "context" + "errors" "fmt" "log/slog" "os" "os/exec" "os/signal" "path/filepath" + "strconv" "strings" + "sync" "syscall" "time" @@ -57,74 +60,87 @@ func configureSignalHandler(myTargets []target.Target, statusFunc progress.Multi sigChannel := make(chan os.Signal, 1) signal.Notify(sigChannel, syscall.SIGINT, syscall.SIGTERM) go func() { + // wait for a signal sig := <-sigChannel slog.Debug("received signal", slog.String("signal", sig.String())) // The controller script is run in its own process group, so we need to send the signal // directly to the PID of the controller. For every target, look for the controller - // PID file and send SIGINT to it. + // PID file and send SIGINT to it, then wait for it to exit concurrently. + var wg sync.WaitGroup for _, t := range myTargets { if statusFunc != nil { _ = statusFunc(t.GetName(), "Signal received, cleaning up...") } pidFilePath := filepath.Join(t.GetTempDirectory(), script.ControllerPIDFileName) - stdout, _, exitcode, err := t.RunCommandEx(exec.Command("cat", pidFilePath), 5, false, true) // #nosec G204 + stdout, _, _, err := t.RunCommandEx(exec.Command("cat", pidFilePath), 5, false, true) // #nosec G204 + // if there's an error and the error type is exec.ExitError, the file likely doesn't exist + // so we can skip sending the signal to this target if err != nil { - slog.Error("error retrieving target controller PID", slog.String("target", t.GetName()), slog.String("error", err.Error())) + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + slog.Debug("target controller PID file not found, assuming script has already exited", slog.String("target", t.GetName())) + continue + } + slog.Error("failed to retrieve target controller PID", slog.String("target", t.GetName()), slog.String("error", err.Error())) continue } - if exitcode == 0 { - pidStr := strings.TrimSpace(stdout) - err = signalProcessOnTarget(t, pidStr, "SIGINT") - if err != nil { - slog.Error("error sending SIGINT signal to target controller", slog.String("target", t.GetName()), slog.String("error", err.Error())) - } + pid := strings.TrimSpace(stdout) + // confirm pid is a valid integer + if _, err := strconv.Atoi(pid); err != nil { + slog.Error("invalid PID retrieved from target controller PID file", slog.String("target", t.GetName()), slog.String("pid", pid), slog.String("error", err.Error())) + continue } - } - // now wait until all controller scripts have exited - slog.Debug("waiting for controller scripts to exit") - for _, t := range myTargets { - // create a per-target timeout context - targetTimeout := 10 * time.Second - ctx, cancel := context.WithTimeout(context.Background(), targetTimeout) - timedOut := false - pidFilePath := filepath.Join(t.GetTempDirectory(), script.ControllerPIDFileName) - // read the pid file - stdout, _, exitcode, err := t.RunCommandEx(exec.Command("cat", pidFilePath), 5, false, true) // #nosec G204 - if err != nil || exitcode != 0 { - slog.Debug("target controller PID file no longer exists, assuming script has exited or is in the process of exiting", slog.String("target", t.GetName())) - cancel() + // send SIGINT to the controller process on the target + slog.Debug("signaling target controller process with SIGINT", slog.String("target", t.GetName()), slog.String("pid", pid)) + err = signalProcessOnTarget(t, pid, "SIGINT") + if err != nil { + slog.Error("failed to send SIGINT signal to target controller", slog.String("target", t.GetName()), slog.String("error", err.Error())) continue } - pidStr := strings.TrimSpace(stdout) - for { - // determine if the process still exists - _, _, exitcode, err = t.RunCommandEx(exec.Command("ps", "-p", pidStr), 5, false, true) // #nosec G204 - if err != nil || exitcode != 0 { - slog.Debug("target controller process no longer exists", slog.String("target", t.GetName())) - break - } - // check for timeout - select { - case <-ctx.Done(): - timedOut = true - default: - } - if timedOut { - if statusFunc != nil { - _ = statusFunc(t.GetName(), "cleanup timeout exceeded, sending kill signal") - } - slog.Warn("signal handler cleanup timeout exceeded for target, sending SIGKILL", slog.String("target", t.GetName())) - err = signalProcessOnTarget(t, pidStr, "SIGKILL") + // spawn a goroutine to wait for this target's controller to exit + wg.Add(1) + go func(tgt target.Target, pid string) { + defer wg.Done() + // create a per-target timeout context + targetTimeout := 10 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), targetTimeout) + defer cancel() + timedOut := false + for { + // determine if the process still exists + _, _, _, err := tgt.RunCommandEx(exec.Command("ps", "-p", pid), 5, false, true) // #nosec G204 if err != nil { - slog.Error("error sending SIGKILL signal to target controller", slog.String("target", t.GetName()), slog.String("error", err.Error())) + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + slog.Debug("target controller process no longer exists", slog.String("target", tgt.GetName())) + break + } + slog.Error("failed to check target controller process", slog.String("target", tgt.GetName()), slog.String("error", err.Error())) + break } - break + // check for timeout + select { + case <-ctx.Done(): + timedOut = true + default: + } + if timedOut { + if statusFunc != nil { + _ = statusFunc(tgt.GetName(), "cleanup timeout exceeded, sending kill signal") + } + slog.Warn("signal handler cleanup timeout exceeded for target, sending SIGKILL", slog.String("target", tgt.GetName())) + err := signalProcessOnTarget(tgt, pid, "SIGKILL") + if err != nil { + slog.Error("failed to send SIGKILL signal to target controller", slog.String("target", tgt.GetName()), slog.String("error", err.Error())) + } + break + } + // sleep for a short time before checking again + time.Sleep(500 * time.Millisecond) } - // sleep for a short time before checking again - time.Sleep(500 * time.Millisecond) - } - cancel() + }(t, pid) } + wg.Wait() // Race condition between the controller script deleting its PID file and it truly exiting. // Future work: reconsider decision to have the controller script delete its own PID file. @@ -134,20 +150,20 @@ func configureSignalHandler(myTargets []target.Target, statusFunc progress.Multi time.Sleep(500 * time.Millisecond) // send SIGINT to perfspect's remaining children, if any - myPid := os.Getpid() - children, err := util.GetChildren(myPid) + perfspectPid := os.Getpid() + children, err := util.GetChildren(perfspectPid) if err != nil { - slog.Error("error retrieving child processes", slog.String("error", err.Error())) + slog.Error("failed to retrieve perfspect's child processes", slog.String("error", err.Error())) return } if len(children) == 0 { - slog.Debug("no child processes to signal") + slog.Debug("perfspect has no child processes to signal") return } slog.Debug("signaling child processes", slog.String("child PIDs", fmt.Sprintf("%v", children))) err = util.SignalChildren(syscall.SIGINT) if err != nil { - slog.Error("error sending signal to children", slog.String("error", err.Error())) + slog.Error("failed to send SIGINT signal to perfspect's child processes", slog.String("error", err.Error())) } }() }