Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion downloader/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func TestAll(t *testing.T) {
{"testAuthSuccess", testAuthSuccess},
{"testAuthErrors", testAuthErrors},
{"testContextCancelledFail", testContextCancelledFail},
{"testContextCancelledPrepareRepo", testContextCancelledPrepareRepo},
{"testWrongEndpointFail", testWrongEndpointFail},
{"testAlreadyDownloadedFail", testAlreadyDownloadedFail},
{"testDownloadConcurrentSuccess", testDownloadConcurrentSuccess},
Expand Down Expand Up @@ -250,7 +251,21 @@ func testContextCancelledFail(t *testing.T, h *testhelper.Helper) {
}
job.SetEndpoints([]string{endPoint(gitProtocol, testRepo)})

require.Equal(t, fmt.Errorf("context canceled"), Download(ctx, job))
require.Equal(t, context.Canceled, Download(ctx, job))
}

// testContextCancelledPrepareRepo
// 1) tries to prepare a repository with a cancelled context. Previously this
// caused a race condition now it should be correct.
func testContextCancelledPrepareRepo(t *testing.T, h *testhelper.Helper) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

testRepo := tests[0].repoIDs[0]
repo, err := PrepareRepository(ctx, h.Lib, "location", testRepo,
endPoint(gitProtocol, testRepo), h.TempFS, "tmp")
require.Error(t, err)
require.Nil(t, repo)
}

// testWrongEndpointFail
Expand Down
59 changes: 41 additions & 18 deletions downloader/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,33 +264,27 @@ func createRootedRepo(
return nil, err
}

done := make(chan struct{})
go func() {
err = recursiveCopy(
"/", repo.FS(),
clonedPath, clonedFS,
)

close(done)
}()

select {
case <-done:
case <-ctx.Done():
err = ctx.Err()
repo.Close()
err = recursiveCopy(ctx, "/", repo.FS(), clonedPath, clonedFS)
if err != nil {
repo = nil
}

return repo, err
}

func recursiveCopy(
ctx context.Context,
dst string,
dstFS billy.Filesystem,
src string,
srcFS billy.Filesystem,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

stat, err := srcFS.Stat(src)
if err != nil {
return err
Expand All @@ -311,13 +305,13 @@ func recursiveCopy(
srcPath := filepath.Join(src, file.Name())
dstPath := filepath.Join(dst, file.Name())

err = recursiveCopy(dstPath, dstFS, srcPath, srcFS)
err = recursiveCopy(ctx, dstPath, dstFS, srcPath, srcFS)
if err != nil {
return err
}
}
} else {
err = copyFile(dst, dstFS, src, srcFS, stat.Mode())
err = copyFile(ctx, dst, dstFS, src, srcFS, stat.Mode())
if err != nil {
return err
}
Expand All @@ -327,12 +321,19 @@ func recursiveCopy(
}

func copyFile(
ctx context.Context,
dst string,
dstFS billy.Filesystem,
src string,
srcFS billy.Filesystem,
mode os.FileMode,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

_, err := srcFS.Stat(src)
if err != nil {
return err
Expand All @@ -350,7 +351,7 @@ func copyFile(
}
defer fd.Close()

_, err = io.Copy(fd, fo)
_, err = io.Copy(fd, newContextReader(ctx, fo))
if err != nil {
fd.Close()
dstFS.Remove(dst)
Expand All @@ -359,3 +360,25 @@ func copyFile(

return nil
}

type contextReader struct {
reader io.Reader
ctx context.Context
}

func newContextReader(ctx context.Context, reader io.Reader) *contextReader {
return &contextReader{
ctx: ctx,
reader: reader,
}
}

func (c *contextReader) Read(p []byte) (n int, err error) {
select {
case <-c.ctx.Done():
return 0, c.ctx.Err()
default:
}

return c.reader.Read(p)
}