Skip to content

Commit dc239eb

Browse files
Make pgroll pull pull only migrations that don't exist in target directory (#811)
Make the `pgroll pull` command pull only those migrations from the target database that are missing in the local directory. This ensures that existing migrations in the local directory are not overwritten or modified when pulling from the target database, for example with formatting or order-of-field changes.
1 parent e028d24 commit dc239eb

File tree

5 files changed

+361
-61
lines changed

5 files changed

+361
-61
lines changed

cmd/pull.go

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ package cmd
44

55
import (
66
"fmt"
7-
8-
"github.com/xataio/pgroll/cmd/flags"
9-
"github.com/xataio/pgroll/pkg/state"
7+
"os"
8+
"path/filepath"
109

1110
"github.com/spf13/cobra"
11+
"github.com/xataio/pgroll/pkg/migrations"
1212
)
1313

1414
func pullCmd() *cobra.Command {
@@ -27,34 +27,51 @@ func pullCmd() *cobra.Command {
2727
ctx := cmd.Context()
2828
targetDir := args[0]
2929

30-
state, err := state.New(ctx, flags.PostgresURL(), flags.StateSchema())
30+
m, err := NewRoll(ctx)
3131
if err != nil {
3232
return err
3333
}
34-
defer state.Close()
34+
defer m.Close()
3535

3636
// Ensure that pgroll is initialized
37-
ok, err := state.IsInitialized(cmd.Context())
37+
ok, err := m.State().IsInitialized(cmd.Context())
3838
if err != nil {
3939
return err
4040
}
4141
if !ok {
4242
return errPGRollNotInitialized
4343
}
4444

45-
migs, err := state.SchemaHistory(ctx, flags.Schema())
45+
// Ensure that the target directory is valid, creating it if it doesn't
46+
// exist
47+
_, err = os.Stat(targetDir)
48+
if err != nil {
49+
if os.IsNotExist(err) {
50+
err := os.MkdirAll(targetDir, 0o755)
51+
if err != nil {
52+
return fmt.Errorf("failed to create target directory: %w", err)
53+
}
54+
} else {
55+
return fmt.Errorf("failed to stat directory: %w", err)
56+
}
57+
}
58+
59+
// Get the list of missing migrations (those that have been applied to
60+
// the target database but are missing in the local directory).
61+
migs, err := m.MissingMigrations(ctx, os.DirFS(targetDir))
4662
if err != nil {
47-
return fmt.Errorf("failed to read schema history: %w", err)
63+
return fmt.Errorf("failed to read migrations from target directory: %w", err)
4864
}
4965

66+
// Write the missing migrations to the target directory
5067
for i, mig := range migs {
5168
prefix := ""
5269
if withPrefixes {
5370
prefix = fmt.Sprintf("%04d", i+1) + "_"
5471
}
55-
err := mig.WriteToFile(targetDir, prefix, useJSON)
72+
err := writeMigrationToFile(mig, targetDir, prefix, useJSON)
5673
if err != nil {
57-
return fmt.Errorf("failed to write migration %q: %w", mig.Migration.Name, err)
74+
return fmt.Errorf("failed to write migration %q: %w", mig.Name, err)
5875
}
5976
}
6077
return nil
@@ -66,3 +83,33 @@ func pullCmd() *cobra.Command {
6683

6784
return pullCmd
6885
}
86+
87+
// WriteToFile writes the migration to a file in `targetDir`, prefixing the
88+
// filename with `prefix`. The output format defaults to YAML, but can
89+
// be changed to JSON by setting `useJSON` to true.
90+
func writeMigrationToFile(m *migrations.Migration, targetDir, prefix string, useJSON bool) error {
91+
err := os.MkdirAll(targetDir, 0o755)
92+
if err != nil {
93+
return err
94+
}
95+
96+
suffix := "yaml"
97+
if useJSON {
98+
suffix = "json"
99+
}
100+
101+
fileName := fmt.Sprintf("%s%s.%s", prefix, m.Name, suffix)
102+
filePath := filepath.Join(targetDir, fileName)
103+
104+
file, err := os.Create(filePath)
105+
if err != nil {
106+
return err
107+
}
108+
defer file.Close()
109+
110+
if useJSON {
111+
return m.WriteAsJSON(file)
112+
} else {
113+
return m.WriteAsYAML(file)
114+
}
115+
}

pkg/migrations/migrations.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ package migrations
44

55
import (
66
"context"
7+
"encoding/json"
78
"fmt"
9+
"io"
810

911
_ "github.com/lib/pq"
12+
"sigs.k8s.io/yaml"
1013

1114
"github.com/xataio/pgroll/pkg/db"
1215
"github.com/xataio/pgroll/pkg/schema"
@@ -102,3 +105,22 @@ func (m *Migration) ContainsRawSQLOperation() bool {
102105
}
103106
return false
104107
}
108+
109+
// WriteAsJSON writes the migration to the given writer in JSON format
110+
func (m *Migration) WriteAsJSON(w io.Writer) error {
111+
encoder := json.NewEncoder(w)
112+
encoder.SetIndent("", " ")
113+
114+
return encoder.Encode(m)
115+
}
116+
117+
// WriteAsYAML writes the migration to the given writer in YAML format
118+
func (m *Migration) WriteAsYAML(w io.Writer) error {
119+
yml, err := yaml.Marshal(m)
120+
if err != nil {
121+
return err
122+
}
123+
124+
_, err = w.Write(yml)
125+
return err
126+
}

pkg/roll/missing.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package roll
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"io/fs"
9+
10+
"github.com/xataio/pgroll/pkg/migrations"
11+
)
12+
13+
// MissingMigrations returns the slice of migrations that have been applied to
14+
// the target database but are missing from the local migrations directory
15+
// `dir`.
16+
func (m *Roll) MissingMigrations(ctx context.Context, dir fs.FS) ([]*migrations.Migration, error) {
17+
// Determine the latest version of the database
18+
latestVersion, err := m.State().LatestVersion(ctx, m.Schema())
19+
if err != nil {
20+
return nil, fmt.Errorf("determining latest version: %w", err)
21+
}
22+
23+
// If no migrations are applied, return a nil slice
24+
if latestVersion == nil {
25+
return nil, nil
26+
}
27+
28+
// Collect all migration files from the directory
29+
files, err := migrations.CollectFilesFromDir(dir)
30+
if err != nil {
31+
return nil, fmt.Errorf("reading migration files: %w", err)
32+
}
33+
34+
// Create a set of local migration names for fast lookup
35+
localMigNames := make(map[string]struct{}, len(files))
36+
for _, file := range files {
37+
mig, err := migrations.ReadMigration(dir, file)
38+
if err != nil {
39+
return nil, fmt.Errorf("reading migration file %s: %w", file, err)
40+
}
41+
localMigNames[mig.Name] = struct{}{}
42+
}
43+
44+
// Get the full schema history from the database
45+
history, err := m.State().SchemaHistory(ctx, m.Schema())
46+
if err != nil {
47+
return nil, fmt.Errorf("reading schema history: %w", err)
48+
}
49+
50+
// Find all migrations that have been applied to the database but are missing
51+
// from the local directory
52+
migs := make([]*migrations.Migration, 0, len(history))
53+
for _, h := range history {
54+
if _, ok := localMigNames[h.Migration.Name]; ok {
55+
continue
56+
}
57+
migs = append(migs, &h.Migration)
58+
}
59+
60+
return migs, nil
61+
}

0 commit comments

Comments
 (0)