Skip to content

Commit 1f6f49b

Browse files
Add a sql2pgroll package to convert SQL to pgroll migrations (#502)
Add a `sql2pgroll` package to convert SQL to `pgroll` migrations. Add a (hidden for now) `pgroll sql` command that uses the package to convert SQL strings on the command line to `pgroll` migrations. The `sql2pgroll` package is incomplete, with almost all SQL falling back to conversion using raw SQL migrations. Only some `CREATE TABLE` statements and the `ALTER TABLE ... ALTER COLUMN ... SET NOT NULL` statement are currently handled. ```bash $ pgroll sql "create table foo(a serial primary key, b text unique)" ``` ```json [ { "create_table": { "columns": [ { "name": "a", "pk": true, "type": "serial" }, { "name": "b", "nullable": true, "type": "text", "unique": true } ], "name": "foo" } } ] ``` Part of #504
1 parent 43a6d97 commit 1f6f49b

File tree

11 files changed

+476
-0
lines changed

11 files changed

+476
-0
lines changed

cmd/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ func Execute() error {
7979
rootCmd.AddCommand(migrateCmd())
8080
rootCmd.AddCommand(pullCmd())
8181
rootCmd.AddCommand(latestCmd())
82+
rootCmd.AddCommand(sqlCmd())
8283

8384
return rootCmd.Execute()
8485
}

cmd/sql.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package cmd
4+
5+
import (
6+
"encoding/json"
7+
"fmt"
8+
"os"
9+
10+
"github.com/spf13/cobra"
11+
"github.com/xataio/pgroll/pkg/sql2pgroll"
12+
)
13+
14+
func sqlCmd() *cobra.Command {
15+
sqlCmd := &cobra.Command{
16+
Use: "sql <sql statement>",
17+
Short: "Convert SQL statements to pgroll operations",
18+
Args: cobra.ExactArgs(1),
19+
Hidden: true,
20+
RunE: func(cmd *cobra.Command, args []string) error {
21+
sql := args[0]
22+
23+
ops, err := sql2pgroll.Convert(sql)
24+
if err != nil {
25+
return fmt.Errorf("failed to convert SQL statement: %w", err)
26+
}
27+
28+
enc := json.NewEncoder(os.Stdout)
29+
enc.SetIndent("", " ")
30+
if err := enc.Encode(ops); err != nil {
31+
return fmt.Errorf("failed to encode operations: %w", err)
32+
}
33+
34+
return nil
35+
},
36+
}
37+
38+
return sqlCmd
39+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ require (
5959
github.com/opencontainers/go-digest v1.0.0 // indirect
6060
github.com/opencontainers/image-spec v1.1.0 // indirect
6161
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
62+
github.com/pganalyze/pg_query_go/v6 v6.0.0 // indirect
6263
github.com/pkg/errors v0.9.1 // indirect
6364
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
6465
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
6969
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
7070
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
7171
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
72+
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
73+
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
7274
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
7375
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
7476
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -140,6 +142,8 @@ github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQ
140142
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
141143
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
142144
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
145+
github.com/pganalyze/pg_query_go/v6 v6.0.0 h1:in6RkR/apfqlAtvqgDxd4Y4o87a5Pr8fkKDB4DrDo2c=
146+
github.com/pganalyze/pg_query_go/v6 v6.0.0/go.mod h1:nvTHIuoud6e1SfrUaFwHqT0i4b5Nr+1rPWVds3B5+50=
143147
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
144148
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
145149
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -325,6 +329,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf h1:
325329
google.golang.org/genproto/googleapis/rpc v0.0.0-20240730163845-b1a4ccb954bf/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
326330
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
327331
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
332+
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
333+
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
328334
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
329335
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
330336
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

pkg/sql2pgroll/alter_table.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll
4+
5+
import (
6+
pgq "github.com/pganalyze/pg_query_go/v6"
7+
"github.com/xataio/pgroll/pkg/migrations"
8+
)
9+
10+
const PlaceHolderSQL = "TODO: Implement SQL data migration"
11+
12+
// convertAlterTableStmt converts an ALTER TABLE statement to pgroll operations.
13+
func convertAlterTableStmt(stmt *pgq.AlterTableStmt) (migrations.Operations, error) {
14+
if stmt.Objtype != pgq.ObjectType_OBJECT_TABLE {
15+
return nil, nil
16+
}
17+
18+
var ops migrations.Operations
19+
for _, cmd := range stmt.Cmds {
20+
alterTableCmd := cmd.GetAlterTableCmd()
21+
if alterTableCmd == nil {
22+
continue
23+
}
24+
25+
//nolint:gocritic
26+
switch alterTableCmd.Subtype {
27+
case pgq.AlterTableType_AT_SetNotNull:
28+
ops = append(ops, convertAlterTableSetNotNull(stmt, alterTableCmd))
29+
}
30+
}
31+
32+
return ops, nil
33+
}
34+
35+
func convertAlterTableSetNotNull(stmt *pgq.AlterTableStmt, cmd *pgq.AlterTableCmd) migrations.Operation {
36+
return &migrations.OpAlterColumn{
37+
Table: stmt.GetRelation().GetRelname(),
38+
Column: cmd.GetName(),
39+
Nullable: ptr(false),
40+
Up: PlaceHolderSQL,
41+
Down: PlaceHolderSQL,
42+
}
43+
}
44+
45+
func ptr[T any](x T) *T {
46+
return &x
47+
}

pkg/sql2pgroll/alter_table_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll_test
4+
5+
import (
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
"github.com/xataio/pgroll/pkg/migrations"
11+
"github.com/xataio/pgroll/pkg/sql2pgroll"
12+
"github.com/xataio/pgroll/pkg/sql2pgroll/expect"
13+
)
14+
15+
func TestConvertAlterTableStatements(t *testing.T) {
16+
t.Parallel()
17+
18+
tests := []struct {
19+
sql string
20+
expectedOp migrations.Operation
21+
}{
22+
{
23+
sql: "ALTER TABLE foo ALTER COLUMN a SET NOT NULL",
24+
expectedOp: expect.AlterTableOp1,
25+
},
26+
}
27+
28+
for _, tc := range tests {
29+
t.Run(tc.sql, func(t *testing.T) {
30+
ops, err := sql2pgroll.Convert(tc.sql)
31+
require.NoError(t, err)
32+
33+
require.Len(t, ops, 1)
34+
35+
alterColumnOps, ok := ops[0].(*migrations.OpAlterColumn)
36+
require.True(t, ok)
37+
38+
assert.Equal(t, tc.expectedOp, alterColumnOps)
39+
})
40+
}
41+
}

pkg/sql2pgroll/convert.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll
4+
5+
import (
6+
"fmt"
7+
8+
pgq "github.com/pganalyze/pg_query_go/v6"
9+
"github.com/xataio/pgroll/pkg/migrations"
10+
)
11+
12+
var ErrStatementCount = fmt.Errorf("expected exactly one statement")
13+
14+
// Convert converts a SQL statement to a slice of pgroll operations.
15+
func Convert(sql string) (migrations.Operations, error) {
16+
ops, err := convert(sql)
17+
if err != nil {
18+
return nil, err
19+
}
20+
21+
if ops == nil {
22+
return makeRawSQLOperation(sql), nil
23+
}
24+
25+
return ops, nil
26+
}
27+
28+
func convert(sql string) (migrations.Operations, error) {
29+
tree, err := pgq.Parse(sql)
30+
if err != nil {
31+
return nil, fmt.Errorf("parse error: %w", err)
32+
}
33+
34+
stmts := tree.GetStmts()
35+
if len(stmts) != 1 {
36+
return nil, fmt.Errorf("%w: got %d statements", ErrStatementCount, len(stmts))
37+
}
38+
node := stmts[0].GetStmt().GetNode()
39+
40+
switch node := (node).(type) {
41+
case *pgq.Node_CreateStmt:
42+
return convertCreateStmt(node.CreateStmt)
43+
case *pgq.Node_AlterTableStmt:
44+
return convertAlterTableStmt(node.AlterTableStmt)
45+
default:
46+
return makeRawSQLOperation(sql), nil
47+
}
48+
}
49+
50+
func makeRawSQLOperation(sql string) migrations.Operations {
51+
return migrations.Operations{
52+
&migrations.OpRawSQL{Up: sql},
53+
}
54+
}

pkg/sql2pgroll/create_table.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll
4+
5+
import (
6+
"fmt"
7+
"strings"
8+
9+
pgq "github.com/pganalyze/pg_query_go/v6"
10+
"github.com/xataio/pgroll/pkg/migrations"
11+
)
12+
13+
// convertCreateStmt converts a CREATE TABLE statement to a pgroll operation.
14+
func convertCreateStmt(stmt *pgq.CreateStmt) ([]migrations.Operation, error) {
15+
columns := make([]migrations.Column, 0, len(stmt.TableElts))
16+
for _, elt := range stmt.TableElts {
17+
columns = append(columns, convertColumnDef(elt.GetColumnDef()))
18+
}
19+
20+
return migrations.Operations{
21+
&migrations.OpCreateTable{
22+
Name: stmt.Relation.Relname,
23+
Columns: columns,
24+
},
25+
}, nil
26+
}
27+
28+
func convertColumnDef(col *pgq.ColumnDef) migrations.Column {
29+
ignoredTypeParts := map[string]bool{
30+
"pg_catalog": true,
31+
}
32+
33+
// Build the type name, including any schema qualifiers
34+
typeParts := make([]string, 0, len(col.GetTypeName().Names))
35+
for _, node := range col.GetTypeName().Names {
36+
typePart := node.GetString_().GetSval()
37+
if _, ok := ignoredTypeParts[typePart]; ok {
38+
continue
39+
}
40+
typeParts = append(typeParts, typePart)
41+
}
42+
43+
// Build the type modifiers, such as precision and scale for numeric types
44+
var typeMods []string
45+
for _, node := range col.GetTypeName().Typmods {
46+
if x, ok := node.GetAConst().Val.(*pgq.A_Const_Ival); ok {
47+
typeMods = append(typeMods, fmt.Sprintf("%d", x.Ival.GetIval()))
48+
}
49+
}
50+
var typeModifier string
51+
if len(typeMods) > 0 {
52+
typeModifier = fmt.Sprintf("(%s)", strings.Join(typeMods, ","))
53+
}
54+
55+
// Build the array bounds for array types
56+
var arrayBounds string
57+
for _, node := range col.GetTypeName().ArrayBounds {
58+
bound := node.GetInteger().GetIval()
59+
if bound == -1 {
60+
arrayBounds = "[]"
61+
} else {
62+
arrayBounds = fmt.Sprintf("%s[%d]", arrayBounds, bound)
63+
}
64+
}
65+
66+
// Determine column nullability, uniqueness, and primary key status
67+
var notNull, unique, pk bool
68+
var defaultValue *string
69+
for _, constraint := range col.Constraints {
70+
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_NOTNULL {
71+
notNull = true
72+
}
73+
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_UNIQUE {
74+
unique = true
75+
}
76+
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_PRIMARY {
77+
pk = true
78+
notNull = true
79+
}
80+
}
81+
82+
return migrations.Column{
83+
Name: col.Colname,
84+
Type: strings.Join(typeParts, ".") + typeModifier + arrayBounds,
85+
Nullable: !notNull,
86+
Unique: unique,
87+
Default: defaultValue,
88+
Pk: pk,
89+
}
90+
}

pkg/sql2pgroll/create_table_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll_test
4+
5+
import (
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
"github.com/xataio/pgroll/pkg/migrations"
11+
"github.com/xataio/pgroll/pkg/sql2pgroll"
12+
"github.com/xataio/pgroll/pkg/sql2pgroll/expect"
13+
)
14+
15+
func TestConvertCreateTableStatements(t *testing.T) {
16+
t.Parallel()
17+
18+
tests := []struct {
19+
sql string
20+
expectedOp migrations.Operation
21+
}{
22+
{
23+
sql: "CREATE TABLE foo(a int)",
24+
expectedOp: expect.CreateTableOp1,
25+
},
26+
{
27+
sql: "CREATE TABLE foo(a int NOT NULL)",
28+
expectedOp: expect.CreateTableOp2,
29+
},
30+
{
31+
sql: "CREATE TABLE foo(a varchar(255))",
32+
expectedOp: expect.CreateTableOp3,
33+
},
34+
{
35+
sql: "CREATE TABLE foo(a numeric(10, 2))",
36+
expectedOp: expect.CreateTableOp4,
37+
},
38+
{
39+
sql: "CREATE TABLE foo(a int UNIQUE)",
40+
expectedOp: expect.CreateTableOp5,
41+
},
42+
{
43+
sql: "CREATE TABLE foo(a int PRIMARY KEY)",
44+
expectedOp: expect.CreateTableOp6,
45+
},
46+
{
47+
sql: "CREATE TABLE foo(a text[])",
48+
expectedOp: expect.CreateTableOp7,
49+
},
50+
{
51+
sql: "CREATE TABLE foo(a text[5])",
52+
expectedOp: expect.CreateTableOp8,
53+
},
54+
{
55+
sql: "CREATE TABLE foo(a text[5][3])",
56+
expectedOp: expect.CreateTableOp9,
57+
},
58+
}
59+
60+
for _, tc := range tests {
61+
t.Run(tc.sql, func(t *testing.T) {
62+
ops, err := sql2pgroll.Convert(tc.sql)
63+
require.NoError(t, err)
64+
65+
require.Len(t, ops, 1)
66+
67+
createTableOp, ok := ops[0].(*migrations.OpCreateTable)
68+
require.True(t, ok)
69+
70+
assert.Equal(t, tc.expectedOp, createTableOp)
71+
})
72+
}
73+
}

0 commit comments

Comments
 (0)