Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 20 additions & 38 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@ func Run(name string, args []string) error {
opts.args = flags.Args()
setupLogging(opts)

ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()

switch {
case opts.version:
fmt.Fprintf(os.Stdout, "gotestsum version %s\n", version)
return nil
case opts.watch:
return runWatcher(opts)
return runWatcher(ctx, opts)
}
return run(opts)
return run(ctx, opts)
}

func setupFlags(name string) (*pflag.FlagSet, *options) {
Expand Down Expand Up @@ -233,7 +236,7 @@ func defaultNoColor() bool {
// try to detect these CI environments via their environment variables.
// This code is based on https://github.com/jwalton/go-supportscolor
if value, exists := os.LookupEnv("CI"); exists {
var ciEnvNames = []string{
ciEnvNames := []string{
"APPVEYOR",
"BUILDKITE",
"CIRCLECI",
Expand Down Expand Up @@ -268,8 +271,8 @@ func setupLogging(opts *options) {
color.NoColor = opts.noColor
}

func run(opts *options) error {
ctx, cancel := context.WithCancel(context.Background())
func run(ctx context.Context, opts *options) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if err := opts.Validate(); err != nil {
Expand Down Expand Up @@ -435,6 +438,7 @@ type proc struct {
signal int32
}

// waiter interface is used to allow testing with mocks.
type waiter interface {
Wait() error
}
Expand Down Expand Up @@ -464,9 +468,7 @@ func startGoTest(ctx context.Context, dir string, args []string) (*proc, error)
}
log.Debugf("go test pid: %d", cmd.Process.Pid)

ctx, cancel := context.WithCancel(ctx)
newSignalHandler(ctx, cmd.Process.Pid, &p)
p.cmd = &cancelWaiter{cancel: cancel, wrapped: p.cmd}
return &p, nil
}

Expand Down Expand Up @@ -510,40 +512,20 @@ func (e exitError) ExitCode() int {
const signalExitCode = 128

func newSignalHandler(ctx context.Context, pid int, p *proc) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)

go func() {
defer signal.Stop(c)

select {
case <-ctx.Done():
<-ctx.Done() // when ctx is cancelled, find process and kill it
atomic.StoreInt32(&p.signal, int32(syscall.SIGINT))
proc, err := os.FindProcess(pid)
if err != nil {
log.Errorf("failed to find pid of 'go test': %v", err)
return
case s := <-c:
atomic.StoreInt32(&p.signal, int32(s.(syscall.Signal)))

proc, err := os.FindProcess(pid)
if err != nil {
log.Errorf("failed to find pid of 'go test': %v", err)
return
}
if err := proc.Signal(s); err != nil {
log.Errorf("failed to interrupt 'go test': %v", err)
return
}
if err := proc.Signal(os.Interrupt); err != nil {
if errors.Is(err, os.ErrProcessDone) {
return // process already exited
}
log.Errorf("failed to interrupt 'go test': %v", err)
return
}
}()
}

// cancelWaiter wraps a waiter to cancel the context after the wrapped
// Wait exits.
type cancelWaiter struct {
cancel func()
wrapped waiter
}

func (w *cancelWaiter) Wait() error {
err := w.wrapped.Wait()
w.cancel()
return err
}
24 changes: 19 additions & 5 deletions cmd/main_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"bytes"
"context"
goversion "go/version"
"os"
"path/filepath"
Expand All @@ -11,14 +12,15 @@ import (
"testing"
"time"

"gotest.tools/gotestsum/internal/text"
"gotest.tools/v3/assert"
"gotest.tools/v3/env"
"gotest.tools/v3/fs"
"gotest.tools/v3/golden"
"gotest.tools/v3/icmd"
"gotest.tools/v3/poll"
"gotest.tools/v3/skip"

"gotest.tools/gotestsum/internal/text"
)

func TestMain(m *testing.M) {
Expand All @@ -41,6 +43,10 @@ func TestE2E_RerunFails(t *testing.T) {
args []string
expectedErr string
}

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

fn := func(t *testing.T, tc testCase) {
tmpFile := fs.NewFile(t, t.Name()+"-seedfile", fs.WithContent("0"))
defer tmpFile.Remove()
Expand All @@ -58,7 +64,9 @@ func TestE2E_RerunFails(t *testing.T) {
bufStderr := new(bytes.Buffer)
opts.stderr = bufStderr

err := run(opts)
err := run(ctx, opts)
// when we expect an error, it may be wrapped so we do a substring match
// rather than an exact match
if tc.expectedErr != "" {
assert.Error(t, err, tc.expectedErr)
} else {
Expand All @@ -71,7 +79,7 @@ func TestE2E_RerunFails(t *testing.T) {
)
golden.Assert(t, out, "e2e/expected/"+expectedFilename(t.Name()))
}
var testCases = []testCase{
testCases := []testCase{
{
name: "reruns until success",
args: []string{
Expand Down Expand Up @@ -215,6 +223,9 @@ func TestE2E_MaxFails_EndTestRun(t *testing.T) {
t.Skip("too slow for short run")
}

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

tmpFile := fs.NewFile(t, t.Name()+"-seedfile", fs.WithContent("0"))
defer tmpFile.Remove()

Expand All @@ -233,7 +244,7 @@ func TestE2E_MaxFails_EndTestRun(t *testing.T) {
bufStderr := new(bytes.Buffer)
opts.stderr = bufStderr

err := run(opts)
err := run(ctx, opts)
assert.Error(t, err, "ending test run because max failures was reached")
out := text.ProcessLines(t, bufStdout,
text.OpRemoveSummaryLineElapsedTime,
Expand All @@ -249,6 +260,9 @@ func TestE2E_IgnoresWarnings(t *testing.T) {
}
t.Setenv("GITHUB_ACTIONS", "no")

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

flags, opts := setupFlags("gotestsum")
args := []string{
"--rerun-fails=1",
Expand All @@ -264,7 +278,7 @@ func TestE2E_IgnoresWarnings(t *testing.T) {
bufStderr := new(bytes.Buffer)
opts.stderr = bufStderr

err := run(opts)
err := run(ctx, opts)
assert.Error(t, err, "exit status 1")
out := text.ProcessLines(t, bufStdout,
text.OpRemoveSummaryLineElapsedTime,
Expand Down
37 changes: 30 additions & 7 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package cmd

import (
"bytes"
"context"
"encoding/json"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"

"github.com/fatih/color"
Expand All @@ -31,11 +33,15 @@ func TestUsage_WithFlagsFromSetupFlags(t *testing.T) {
golden.Assert(t, buf.String(), "gotestsum-help-text")
}

var noColorMu sync.Mutex // prevent data race in parallel tests

func patchNoColor(t *testing.T, value bool) {
noColorMu.Lock()
orig := color.NoColor
color.NoColor = value
t.Cleanup(func() {
color.NoColor = orig
noColorMu.Unlock()
})
}

Expand All @@ -58,7 +64,7 @@ func TestOptions_Validate_FromFlags(t *testing.T) {
}
assert.ErrorContains(t, err, tc.expected, "opts: %#v", opts)
}
var testCases = []testCase{
testCases := []testCase{
{
name: "no flags",
},
Expand Down Expand Up @@ -338,6 +344,9 @@ func runCase(t *testing.T, name string, fn func(t *testing.T)) {
}

func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

jsonFailed := `{"Package": "pkg", "Action": "run"}
{"Package": "pkg", "Test": "TestOne", "Action": "run"}
{"Package": "pkg", "Test": "TestOne", "Action": "fail"}
Expand Down Expand Up @@ -367,11 +376,14 @@ func TestRun_RerunFails_WithTooManyInitialFailures(t *testing.T) {
stderr: os.Stderr,
hideSummary: newHideSummaryValue(),
}
err := run(opts)
err := run(ctx, opts)
assert.ErrorContains(t, err, "number of test failures (2) exceeds maximum (1)", out.String())
}

func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

jsonFailed := `{"Package": "pkg", "Action": "run"}
{"Package": "pkg", "Test": "TestOne", "Action": "run"}
{"Package": "pkg", "Test": "TestOne", "Action": "fail"}
Expand Down Expand Up @@ -401,7 +413,7 @@ func TestRun_RerunFails_BuildErrorPreventsRerun(t *testing.T) {
stderr: os.Stderr,
hideSummary: newHideSummaryValue(),
}
err := run(opts)
err := run(ctx, opts)
assert.ErrorContains(t, err, "rerun aborted because previous run had errors", out.String())
}

Expand All @@ -415,6 +427,8 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) {
{"Package": "pkg", "Test": "TestOne", "Action": "output","Output":"panic: something went wrong\n"}
{"Package": "pkg", "Action": "fail"}
`
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

fn := func([]string) *proc {
return &proc{
Expand All @@ -437,14 +451,17 @@ func TestRun_RerunFails_PanicPreventsRerun(t *testing.T) {
stderr: os.Stderr,
hideSummary: newHideSummaryValue(),
}
err := run(opts)
err := run(ctx, opts)
assert.ErrorContains(t, err, "rerun aborted because previous run had a suspected panic", out.String())
}

func TestRun_InputFromStdin(t *testing.T) {
stdin := os.Stdin
t.Cleanup(func() { os.Stdin = stdin })

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

r, w, err := os.Pipe()
assert.NilError(t, err)
t.Cleanup(func() { _ = r.Close() })
Expand All @@ -466,7 +483,7 @@ func TestRun_InputFromStdin(t *testing.T) {
}()

stdout := new(bytes.Buffer)
err = run(&options{
err = run(ctx, &options{
args: []string{"cat"},
format: "testname",
hideSummary: newHideSummaryValue(),
Expand All @@ -484,6 +501,9 @@ func TestRun_JsonFileIsSyncedBeforePostRunCommand(t *testing.T) {

input := golden.Get(t, "../../testjson/testdata/input/go-test-json.out")

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

fn := func([]string) *proc {
return &proc{
cmd: fakeWaiter{},
Expand All @@ -510,7 +530,7 @@ func TestRun_JsonFileIsSyncedBeforePostRunCommand(t *testing.T) {
command: []string{"cat", jsonFile},
},
}
err := run(opts)
err := run(ctx, opts)
assert.NilError(t, err)
expected := string(input)
_, actual, _ := strings.Cut(out.String(), "s\n") // remove the DONE line
Expand All @@ -520,6 +540,9 @@ func TestRun_JsonFileIsSyncedBeforePostRunCommand(t *testing.T) {
func TestRun_JsonFileTimingEvents(t *testing.T) {
input := golden.Get(t, "../../testjson/testdata/input/go-test-json.out")

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

fn := func([]string) *proc {
return &proc{
cmd: fakeWaiter{},
Expand All @@ -543,7 +566,7 @@ func TestRun_JsonFileTimingEvents(t *testing.T) {
hideSummary: &hideSummaryValue{value: testjson.SummarizeNone},
jsonFileTimingEvents: jsonFileTiming,
}
err := run(opts)
err := run(ctx, opts)
assert.NilError(t, err)

raw, err := os.ReadFile(jsonFileTiming)
Expand Down
2 changes: 1 addition & 1 deletion cmd/rerunfails_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestGoTestRunFlagFromTestCases(t *testing.T) {
assert.Equal(t, actual, tc.expected)
}

var testCases = map[string]testCase{
testCases := map[string]testCase{
"root test case": {
input: "TestOne",
expected: "-test.run=^TestOne$",
Expand Down
Loading