diff --git a/cmd/main.go b/cmd/main.go index 10abd00c..89f1cd0a 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -32,14 +32,17 @@ func Run(name string, args []string) error { opts.args = flags.Args() setupLogging(opts) + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + switch { case opts.version: fmt.Fprintf(os.Stdout, "gotestsum version %s\n", version) return nil case opts.watch: - return runWatcher(opts) + return runWatcher(ctx, opts) } - return run(opts) + return run(ctx, opts) } func setupFlags(name string) (*pflag.FlagSet, *options) { @@ -233,7 +236,7 @@ func defaultNoColor() bool { // try to detect these CI environments via their environment variables. // This code is based on https://github.com/jwalton/go-supportscolor if value, exists := os.LookupEnv("CI"); exists { - var ciEnvNames = []string{ + ciEnvNames := []string{ "APPVEYOR", "BUILDKITE", "CIRCLECI", @@ -268,8 +271,8 @@ func setupLogging(opts *options) { color.NoColor = opts.noColor } -func run(opts *options) error { - ctx, cancel := context.WithCancel(context.Background()) +func run(ctx context.Context, opts *options) error { + ctx, cancel := context.WithCancel(ctx) defer cancel() if err := opts.Validate(); err != nil { @@ -435,6 +438,7 @@ type proc struct { signal int32 } +// waiter interface is used to allow testing with mocks. type waiter interface { Wait() error } @@ -464,9 +468,7 @@ func startGoTest(ctx context.Context, dir string, args []string) (*proc, error) } log.Debugf("go test pid: %d", cmd.Process.Pid) - ctx, cancel := context.WithCancel(ctx) newSignalHandler(ctx, cmd.Process.Pid, &p) - p.cmd = &cancelWaiter{cancel: cancel, wrapped: p.cmd} return &p, nil } @@ -510,40 +512,20 @@ func (e exitError) ExitCode() int { const signalExitCode = 128 func newSignalHandler(ctx context.Context, pid int, p *proc) { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - go func() { - defer signal.Stop(c) - - select { - case <-ctx.Done(): + <-ctx.Done() // when ctx is cancelled, find process and kill it + atomic.StoreInt32(&p.signal, int32(syscall.SIGINT)) + proc, err := os.FindProcess(pid) + if err != nil { + log.Errorf("failed to find pid of 'go test': %v", err) return - case s := <-c: - atomic.StoreInt32(&p.signal, int32(s.(syscall.Signal))) - - proc, err := os.FindProcess(pid) - if err != nil { - log.Errorf("failed to find pid of 'go test': %v", err) - return - } - if err := proc.Signal(s); err != nil { - log.Errorf("failed to interrupt 'go test': %v", err) - return + } + if err := proc.Signal(os.Interrupt); err != nil { + if errors.Is(err, os.ErrProcessDone) { + return // process already exited } + log.Errorf("failed to interrupt 'go test': %v", err) + return } }() } - -// cancelWaiter wraps a waiter to cancel the context after the wrapped -// Wait exits. -type cancelWaiter struct { - cancel func() - wrapped waiter -} - -func (w *cancelWaiter) Wait() error { - err := w.wrapped.Wait() - w.cancel() - return err -} diff --git a/cmd/main_e2e_test.go b/cmd/main_e2e_test.go index ff9279cc..83f6aa86 100644 --- a/cmd/main_e2e_test.go +++ b/cmd/main_e2e_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "context" goversion "go/version" "os" "path/filepath" @@ -11,7 +12,6 @@ import ( "testing" "time" - "gotest.tools/gotestsum/internal/text" "gotest.tools/v3/assert" "gotest.tools/v3/env" "gotest.tools/v3/fs" @@ -19,6 +19,8 @@ import ( "gotest.tools/v3/icmd" "gotest.tools/v3/poll" "gotest.tools/v3/skip" + + "gotest.tools/gotestsum/internal/text" ) func TestMain(m *testing.M) { @@ -41,6 +43,10 @@ func TestE2E_RerunFails(t *testing.T) { args []string expectedErr string } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + fn := func(t *testing.T, tc testCase) { tmpFile := fs.NewFile(t, t.Name()+"-seedfile", fs.WithContent("0")) defer tmpFile.Remove() @@ -58,7 +64,9 @@ func TestE2E_RerunFails(t *testing.T) { bufStderr := new(bytes.Buffer) opts.stderr = bufStderr - err := run(opts) + err := run(ctx, opts) + // when we expect an error, it may be wrapped so we do a substring match + // rather than an exact match if tc.expectedErr != "" { assert.Error(t, err, tc.expectedErr) } else { @@ -71,7 +79,7 @@ func TestE2E_RerunFails(t *testing.T) { ) golden.Assert(t, out, "e2e/expected/"+expectedFilename(t.Name())) } - var testCases = []testCase{ + testCases := []testCase{ { name: "reruns until success", args: []string{ @@ -215,6 +223,9 @@ func TestE2E_MaxFails_EndTestRun(t *testing.T) { t.Skip("too slow for short run") } + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + tmpFile := fs.NewFile(t, t.Name()+"-seedfile", fs.WithContent("0")) defer tmpFile.Remove() @@ -233,7 +244,7 @@ func TestE2E_MaxFails_EndTestRun(t *testing.T) { bufStderr := new(bytes.Buffer) opts.stderr = bufStderr - err := run(opts) + err := run(ctx, opts) assert.Error(t, err, "ending test run because max failures was reached") out := text.ProcessLines(t, bufStdout, text.OpRemoveSummaryLineElapsedTime, @@ -249,6 +260,9 @@ func TestE2E_IgnoresWarnings(t *testing.T) { } t.Setenv("GITHUB_ACTIONS", "no") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + flags, opts := setupFlags("gotestsum") args := []string{ "--rerun-fails=1", @@ -264,7 +278,7 @@ func TestE2E_IgnoresWarnings(t *testing.T) { bufStderr := new(bytes.Buffer) opts.stderr = bufStderr - err := run(opts) + err := run(ctx, opts) assert.Error(t, err, "exit status 1") out := text.ProcessLines(t, bufStdout, text.OpRemoveSummaryLineElapsedTime, diff --git a/cmd/main_test.go b/cmd/main_test.go index a6195113..3072259b 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -2,12 +2,14 @@ package cmd import ( "bytes" + "context" "encoding/json" "os" "os/exec" "path/filepath" "runtime" "strings" + "sync" "testing" "github.com/fatih/color" @@ -31,11 +33,15 @@ func TestUsage_WithFlagsFromSetupFlags(t *testing.T) { golden.Assert(t, buf.String(), "gotestsum-help-text") } +var noColorMu sync.Mutex // prevent data race in parallel tests + func patchNoColor(t *testing.T, value bool) { + noColorMu.Lock() orig := color.NoColor color.NoColor = value t.Cleanup(func() { color.NoColor = orig + noColorMu.Unlock() }) } @@ -58,7 +64,7 @@ func TestOptions_Validate_FromFlags(t *testing.T) { } assert.ErrorContains(t, err, tc.expected, "opts: %#v", opts) } - var testCases = []testCase{ + testCases := []testCase{ { name: "no flags", }, @@ -338,6 +344,9 @@ func runCase(t *testing.T, name string, fn func(t *testing.T)) { } func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + jsonFailed := `{"Package": "pkg", "Action": "run"} {"Package": "pkg", "Test": "TestOne", "Action": "run"} {"Package": "pkg", "Test": "TestOne", "Action": "fail"} @@ -367,11 +376,14 @@ func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) { stderr: os.Stderr, hideSummary: newHideSummaryValue(), } - err := run(opts) + err := run(ctx, opts) assert.ErrorContains(t, err, "number of test failures (2) exceeds maximum (1)", out.String()) } func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + jsonFailed := `{"Package": "pkg", "Action": "run"} {"Package": "pkg", "Test": "TestOne", "Action": "run"} {"Package": "pkg", "Test": "TestOne", "Action": "fail"} @@ -401,7 +413,7 @@ func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) { stderr: os.Stderr, hideSummary: newHideSummaryValue(), } - err := run(opts) + err := run(ctx, opts) assert.ErrorContains(t, err, "rerun aborted because previous run had errors", out.String()) } @@ -415,6 +427,8 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) { {"Package": "pkg", "Test": "TestOne", "Action": "output","Output":"panic: something went wrong\n"} {"Package": "pkg", "Action": "fail"} ` + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) fn := func([]string) *proc { return &proc{ @@ -437,7 +451,7 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) { stderr: os.Stderr, hideSummary: newHideSummaryValue(), } - err := run(opts) + err := run(ctx, opts) assert.ErrorContains(t, err, "rerun aborted because previous run had a suspected panic", out.String()) } @@ -445,6 +459,9 @@ func TestRun_InputFromStdin(t *testing.T) { stdin := os.Stdin t.Cleanup(func() { os.Stdin = stdin }) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + r, w, err := os.Pipe() assert.NilError(t, err) t.Cleanup(func() { _ = r.Close() }) @@ -466,7 +483,7 @@ func TestRun_InputFromStdin(t *testing.T) { }() stdout := new(bytes.Buffer) - err = run(&options{ + err = run(ctx, &options{ args: []string{"cat"}, format: "testname", hideSummary: newHideSummaryValue(), @@ -484,6 +501,9 @@ func TestRun_JsonFileIsSyncedBeforePostRunCommand(t *testing.T) { input := golden.Get(t, "../../testjson/testdata/input/go-test-json.out") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + fn := func([]string) *proc { return &proc{ cmd: fakeWaiter{}, @@ -510,7 +530,7 @@ func TestRun_JsonFileIsSyncedBeforePostRunCommand(t *testing.T) { command: []string{"cat", jsonFile}, }, } - err := run(opts) + err := run(ctx, opts) assert.NilError(t, err) expected := string(input) _, actual, _ := strings.Cut(out.String(), "s\n") // remove the DONE line @@ -520,6 +540,9 @@ func TestRun_JsonFileIsSyncedBeforePostRunCommand(t *testing.T) { func TestRun_JsonFileTimingEvents(t *testing.T) { input := golden.Get(t, "../../testjson/testdata/input/go-test-json.out") + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + fn := func([]string) *proc { return &proc{ cmd: fakeWaiter{}, @@ -543,7 +566,7 @@ func TestRun_JsonFileTimingEvents(t *testing.T) { hideSummary: &hideSummaryValue{value: testjson.SummarizeNone}, jsonFileTimingEvents: jsonFileTiming, } - err := run(opts) + err := run(ctx, opts) assert.NilError(t, err) raw, err := os.ReadFile(jsonFileTiming) diff --git a/cmd/rerunfails_test.go b/cmd/rerunfails_test.go index 96d4bccf..239dd751 100644 --- a/cmd/rerunfails_test.go +++ b/cmd/rerunfails_test.go @@ -68,7 +68,7 @@ func TestGoTestRunFlagFromTestCases(t *testing.T) { assert.Equal(t, actual, tc.expected) } - var testCases = map[string]testCase{ + testCases := map[string]testCase{ "root test case": { input: "TestOne", expected: "-test.run=^TestOne$", diff --git a/cmd/watch.go b/cmd/watch.go index 91e7a6a1..be78e7c1 100644 --- a/cmd/watch.go +++ b/cmd/watch.go @@ -11,10 +11,7 @@ import ( "gotest.tools/gotestsum/testjson" ) -func runWatcher(opts *options) error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func runWatcher(ctx context.Context, opts *options) error { w := &watchRuns{opts: *opts} return filewatcher.Watch(ctx, opts.packages, opts.watchClear, w.run) } @@ -24,7 +21,7 @@ type watchRuns struct { prevExec *testjson.Execution } -func (w *watchRuns) run(event filewatcher.Event) error { +func (w *watchRuns) run(ctx context.Context, event filewatcher.Event) error { if event.Debug { path, cleanup, err := delveInitFile(w.prevExec) if err != nil { @@ -53,7 +50,7 @@ func (w *watchRuns) run(event filewatcher.Event) error { opts.packages = append(opts.packages, event.Args...) var err error - if w.prevExec, err = runSingle(&opts, dir); !IsExitCoder(err) { + if w.prevExec, err = runSingle(ctx, &opts, dir); !IsExitCoder(err) { return err } return nil @@ -62,8 +59,8 @@ func (w *watchRuns) run(event filewatcher.Event) error { // runSingle is similar to run. It doesn't support rerun-fails. It may be // possible to share runSingle with run, but the defer close on the handler // would require at least 3 return values, so for now it is a copy. -func runSingle(opts *options, dir string) (*testjson.Execution, error) { - ctx, cancel := context.WithCancel(context.Background()) +func runSingle(ctx context.Context, opts *options, dir string) (*testjson.Execution, error) { + ctx, cancel := context.WithCancel(ctx) defer cancel() if err := opts.Validate(); err != nil { diff --git a/internal/filewatcher/watch.go b/internal/filewatcher/watch.go index 7a3afdd8..fddcc91e 100644 --- a/internal/filewatcher/watch.go +++ b/internal/filewatcher/watch.go @@ -38,7 +38,7 @@ type Event struct { // Watch dirs for filesystem events, and run tests when .go files are saved. // //nolint:gocyclo -func Watch(ctx context.Context, dirs []string, clearScreen bool, run func(Event) error) error { +func Watch(ctx context.Context, dirs []string, clearScreen bool, run func(context.Context, Event) error) error { watcher, err := fsnotify.NewWatcher() if err != nil { return fmt.Errorf("failed to create file watcher: %w", err) @@ -80,7 +80,7 @@ func Watch(ctx context.Context, dirs []string, clearScreen bool, run func(Event) } term.Reset() - if err := h.runTests(event); err != nil { + if err := h.runTests(ctx, event); err != nil { return fmt.Errorf("failed to rerun tests for %v: %v", event.PkgPath, err) } term.Start() @@ -94,7 +94,7 @@ func Watch(ctx context.Context, dirs []string, clearScreen bool, run func(Event) continue } - if err := h.handleEvent(event); err != nil { + if err := h.handleEvent(ctx, event); err != nil { return fmt.Errorf("failed to run tests for %v: %v", event.Name, err) } @@ -236,12 +236,12 @@ type fsEventHandler struct { last time.Time lastPath string clearScreen bool - fn func(opts Event) error + fn func(ctx context.Context, opts Event) error } var floodThreshold = 250 * time.Millisecond -func (h *fsEventHandler) handleEvent(event fsnotify.Event) error { +func (h *fsEventHandler) handleEvent(ctx context.Context, event fsnotify.Event) error { if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Rename) == 0 { return nil } @@ -254,10 +254,10 @@ func (h *fsEventHandler) handleEvent(event fsnotify.Event) error { log.Debugf("skipping event received less than %v after the previous", floodThreshold) return nil } - return h.runTests(Event{PkgPath: "./" + filepath.Dir(event.Name)}) + return h.runTests(ctx, Event{PkgPath: "./" + filepath.Dir(event.Name)}) } -func (h *fsEventHandler) runTests(opts Event) error { +func (h *fsEventHandler) runTests(ctx context.Context, opts Event) error { if opts.useLastPath { opts.PkgPath = h.lastPath } @@ -268,7 +268,7 @@ func (h *fsEventHandler) runTests(opts Event) error { fmt.Printf("\nRunning tests in %v\n", opts.PkgPath) - if err := h.fn(opts); err != nil { + if err := h.fn(ctx, opts); err != nil { return err } h.last = time.Now() diff --git a/internal/filewatcher/watch_test.go b/internal/filewatcher/watch_test.go index 65f75451..4d85a356 100644 --- a/internal/filewatcher/watch_test.go +++ b/internal/filewatcher/watch_test.go @@ -1,6 +1,7 @@ package filewatcher import ( + "context" "fmt" "path/filepath" "testing" @@ -22,13 +23,15 @@ func TestFSEventHandler_HandleEvent(t *testing.T) { fn := func(t *testing.T, tc testCase) { var ran bool - run := func(Event) error { + run := func(context.Context, Event) error { ran = true return nil } + ctx := context.Background() + h := fsEventHandler{last: tc.last, fn: run} - err := h.handleEvent(tc.event) + err := h.handleEvent(ctx, tc.event) assert.NilError(t, err) assert.Equal(t, ran, tc.expectedRun) if tc.expectedRun { diff --git a/internal/filewatcher/watch_unix_test.go b/internal/filewatcher/watch_unix_test.go index 5e027302..dc3067d2 100644 --- a/internal/filewatcher/watch_unix_test.go +++ b/internal/filewatcher/watch_unix_test.go @@ -25,7 +25,8 @@ func TestWatch(t *testing.T) { patchFloodThreshold(t, 0) chEvents := make(chan Event, 1) - capture := func(event Event) error { + capture := func(ctx context.Context, event Event) error { + _ = ctx chEvents <- event return nil }