Skip to content

Commit bc24f9e

Browse files
authored
Support parsing multiple SQL statements in sql2pgroll.Convert (#690)
This PR introduces support for translating multiple SQL statements to pgroll operations in `sql2pgroll.Convert`. This PR required by the upcoming improvements to support reading complete SQL migration files produced by ORMs.
1 parent 331c25f commit bc24f9e

File tree

2 files changed

+285
-35
lines changed

2 files changed

+285
-35
lines changed

pkg/sql2pgroll/convert.go

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,52 +10,57 @@ import (
1010
"github.com/xataio/pgroll/pkg/migrations"
1111
)
1212

13-
var ErrStatementCount = fmt.Errorf("expected exactly one statement")
14-
1513
// Convert converts a SQL statement to a slice of pgroll operations.
1614
func Convert(sql string) (migrations.Operations, error) {
17-
ops, err := convert(sql)
18-
if err != nil {
19-
return nil, err
20-
}
21-
22-
if ops == nil {
23-
return makeRawSQLOperation(sql), nil
24-
}
25-
26-
return ops, nil
27-
}
28-
29-
func convert(sql string) (migrations.Operations, error) {
3015
tree, err := pgq.Parse(sql)
3116
if err != nil {
3217
return nil, fmt.Errorf("parse error: %w", err)
3318
}
3419

20+
var migOps migrations.Operations
3521
stmts := tree.GetStmts()
36-
if len(stmts) != 1 {
37-
return nil, fmt.Errorf("%w: got %d statements", ErrStatementCount, len(stmts))
38-
}
39-
node := stmts[0].GetStmt().GetNode()
40-
41-
switch node := (node).(type) {
42-
case *pgq.Node_CreateStmt:
43-
return convertCreateStmt(node.CreateStmt)
44-
case *pgq.Node_AlterTableStmt:
45-
return convertAlterTableStmt(node.AlterTableStmt)
46-
case *pgq.Node_RenameStmt:
47-
return convertRenameStmt(node.RenameStmt)
48-
case *pgq.Node_DropStmt:
49-
return convertDropStatement(node.DropStmt)
50-
case *pgq.Node_IndexStmt:
51-
return convertCreateIndexStmt(node.IndexStmt)
52-
default:
53-
return makeRawSQLOperation(sql), nil
22+
for i, stmt := range stmts {
23+
if stmt.GetStmt() == nil {
24+
continue
25+
}
26+
node := stmts[i].GetStmt().GetNode()
27+
var ops migrations.Operations
28+
var err error
29+
switch node := (node).(type) {
30+
case *pgq.Node_CreateStmt:
31+
ops, err = convertCreateStmt(node.CreateStmt)
32+
case *pgq.Node_AlterTableStmt:
33+
ops, err = convertAlterTableStmt(node.AlterTableStmt)
34+
case *pgq.Node_RenameStmt:
35+
ops, err = convertRenameStmt(node.RenameStmt)
36+
case *pgq.Node_DropStmt:
37+
ops, err = convertDropStatement(node.DropStmt)
38+
case *pgq.Node_IndexStmt:
39+
ops, err = convertCreateIndexStmt(node.IndexStmt)
40+
default:
41+
// SQL statement cannot be transformed to pgroll operation
42+
// so we will use raw SQL operation
43+
ops = makeRawSQLOperation(sql, i)
44+
}
45+
if err != nil {
46+
return nil, err
47+
}
48+
if ops == nil {
49+
ops = makeRawSQLOperation(sql, i)
50+
}
51+
migOps = append(migOps, ops...)
5452
}
53+
return migOps, nil
5554
}
5655

57-
func makeRawSQLOperation(sql string) migrations.Operations {
56+
func makeRawSQLOperation(sql string, idx int) migrations.Operations {
57+
stmts, err := pgq.SplitWithParser(sql, true)
58+
if err != nil {
59+
return migrations.Operations{
60+
&migrations.OpRawSQL{Up: sql},
61+
}
62+
}
5863
return migrations.Operations{
59-
&migrations.OpRawSQL{Up: sql},
64+
&migrations.OpRawSQL{Up: stmts[idx]},
6065
}
6166
}

pkg/sql2pgroll/convert_test.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package sql2pgroll_test
4+
5+
import (
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
10+
"github.com/xataio/pgroll/pkg/migrations"
11+
"github.com/xataio/pgroll/pkg/sql2pgroll"
12+
)
13+
14+
func TestConvertToMigration(t *testing.T) {
15+
tests := map[string]struct {
16+
sql string
17+
expectedOps migrations.Operations
18+
expectedErr bool
19+
}{
20+
"empty SQL statement": {
21+
sql: "",
22+
expectedOps: nil,
23+
expectedErr: false,
24+
},
25+
"single SQL statement": {
26+
sql: "DROP TYPE t1;",
27+
expectedOps: migrations.Operations{
28+
&migrations.OpRawSQL{
29+
Up: "DROP TYPE t1",
30+
},
31+
},
32+
expectedErr: false,
33+
},
34+
"single multiline statement with comments": {
35+
sql: `CREATE TABLE t1 (
36+
id INT, -- my id column
37+
name TEXT -- my name column
38+
);
39+
`,
40+
expectedOps: migrations.Operations{
41+
&migrations.OpCreateTable{
42+
Name: "t1",
43+
Columns: []migrations.Column{
44+
{
45+
Name: "id",
46+
Type: "int",
47+
Nullable: true,
48+
},
49+
{
50+
Name: "name",
51+
Type: "text",
52+
Nullable: true,
53+
},
54+
},
55+
},
56+
},
57+
expectedErr: false,
58+
},
59+
"single function definition multiline with comments": {
60+
sql: `CREATE OR REPLACE FUNCTION check_password(uname TEXT, pass TEXT)
61+
RETURNS BOOLEAN AS $$ -- check password for username
62+
DECLARE passed BOOLEAN;
63+
BEGIN
64+
SELECT (pwd = $2) INTO passed
65+
FROM pwds -- from passwords table
66+
WHERE username = $1; -- select password for username
67+
RETURN passed;
68+
END; $$ LANGUAGE plpgsql
69+
SECURITY DEFINER`,
70+
expectedOps: migrations.Operations{
71+
&migrations.OpRawSQL{
72+
Up: `CREATE OR REPLACE FUNCTION check_password(uname TEXT, pass TEXT)
73+
RETURNS BOOLEAN AS $$ -- check password for username
74+
DECLARE passed BOOLEAN;
75+
BEGIN
76+
SELECT (pwd = $2) INTO passed
77+
FROM pwds -- from passwords table
78+
WHERE username = $1; -- select password for username
79+
RETURN passed;
80+
END; $$ LANGUAGE plpgsql
81+
SECURITY DEFINER`,
82+
},
83+
},
84+
},
85+
"multiple SQL raw migration statements": {
86+
sql: "DROP TYPE t1; DROP TYPE t2;",
87+
expectedOps: migrations.Operations{
88+
&migrations.OpRawSQL{
89+
Up: "DROP TYPE t1",
90+
},
91+
&migrations.OpRawSQL{
92+
Up: "DROP TYPE t2",
93+
},
94+
},
95+
expectedErr: false,
96+
},
97+
"multiple SQL migrations to raw and regular pgroll operations": {
98+
sql: "CREATE TABLE t1 (id INT); DROP INDEX idx1; DROP TYPE t1; ALTER TABLE t1 ADD COLUMN name TEXT;",
99+
expectedOps: migrations.Operations{
100+
&migrations.OpCreateTable{
101+
Name: "t1",
102+
Columns: []migrations.Column{
103+
{
104+
Name: "id",
105+
Type: "int",
106+
Nullable: true,
107+
},
108+
},
109+
},
110+
&migrations.OpDropIndex{
111+
Name: "idx1",
112+
},
113+
&migrations.OpRawSQL{
114+
Up: "DROP TYPE t1",
115+
},
116+
&migrations.OpAddColumn{
117+
Table: "t1",
118+
Column: migrations.Column{
119+
Name: "name",
120+
Type: "text",
121+
Nullable: true,
122+
},
123+
Up: sql2pgroll.PlaceHolderSQL,
124+
},
125+
},
126+
expectedErr: false,
127+
},
128+
"multiple unknown DDL statements": {
129+
sql: "CREATE TYPE t1 AS ENUM ('a', 'b'); CREATE DOMAIN d1 AS TEXT; CREATE SCHEMA s1; CREATE EXTENSION e1;",
130+
expectedOps: migrations.Operations{
131+
&migrations.OpRawSQL{
132+
Up: "CREATE TYPE t1 AS ENUM ('a', 'b')",
133+
},
134+
&migrations.OpRawSQL{
135+
Up: "CREATE DOMAIN d1 AS TEXT",
136+
},
137+
&migrations.OpRawSQL{
138+
Up: "CREATE SCHEMA s1",
139+
},
140+
&migrations.OpRawSQL{
141+
Up: "CREATE EXTENSION e1",
142+
},
143+
},
144+
expectedErr: false,
145+
},
146+
"multiple empty SQL statements": {
147+
sql: ";;",
148+
},
149+
"multiple statements with empty SQL statement": {
150+
sql: "CREATE TABLE t1 (id INT);; DROP TYPE t1;;",
151+
expectedOps: migrations.Operations{
152+
&migrations.OpCreateTable{
153+
Name: "t1",
154+
Columns: []migrations.Column{
155+
{
156+
Name: "id",
157+
Type: "int",
158+
Nullable: true,
159+
},
160+
},
161+
},
162+
&migrations.OpRawSQL{
163+
Up: "DROP TYPE t1",
164+
},
165+
},
166+
expectedErr: false,
167+
},
168+
"multiple multiline statments with comments": {
169+
sql: `DROP TYPE t1; -- drop type t1
170+
DROP INDEX ixd1; -- drop my index
171+
`,
172+
expectedOps: migrations.Operations{
173+
&migrations.OpRawSQL{
174+
Up: "DROP TYPE t1",
175+
},
176+
&migrations.OpDropIndex{
177+
Name: "ixd1",
178+
},
179+
},
180+
expectedErr: false,
181+
},
182+
"multiple statements with function definition multiline with comments": {
183+
sql: `DROP TABLE t1; DROP INDEX idx2; CREATE OR REPLACE FUNCTION check_password(uname TEXT, pass TEXT)
184+
RETURNS BOOLEAN AS $$ -- check password for username
185+
DECLARE passed BOOLEAN;
186+
BEGIN
187+
SELECT (pwd = $2) INTO passed
188+
FROM pwds -- from passwords table
189+
WHERE username = $1; -- select password for username
190+
RETURN passed;
191+
END; $$ LANGUAGE plpgsql
192+
SECURITY DEFINER;
193+
CREATE INDEX idx1 ON t1 (id);
194+
CREATE TYPE t1;`,
195+
expectedOps: migrations.Operations{
196+
&migrations.OpDropTable{
197+
Name: "t1",
198+
},
199+
&migrations.OpDropIndex{
200+
Name: "idx2",
201+
},
202+
&migrations.OpRawSQL{
203+
Up: `CREATE OR REPLACE FUNCTION check_password(uname TEXT, pass TEXT)
204+
RETURNS BOOLEAN AS $$ -- check password for username
205+
DECLARE passed BOOLEAN;
206+
BEGIN
207+
SELECT (pwd = $2) INTO passed
208+
FROM pwds -- from passwords table
209+
WHERE username = $1; -- select password for username
210+
RETURN passed;
211+
END; $$ LANGUAGE plpgsql
212+
SECURITY DEFINER`,
213+
},
214+
&migrations.OpCreateIndex{
215+
Name: "idx1",
216+
Table: "t1",
217+
Columns: []string{"id"},
218+
Method: "btree",
219+
},
220+
&migrations.OpRawSQL{
221+
Up: "CREATE TYPE t1",
222+
},
223+
},
224+
expectedErr: false,
225+
},
226+
"syntax error in second statement": {
227+
sql: "DROP INDEX idx1; DROP INDX idx2",
228+
expectedOps: nil,
229+
expectedErr: true,
230+
},
231+
}
232+
233+
for name, tc := range tests {
234+
t.Run(name, func(t *testing.T) {
235+
ops, err := sql2pgroll.Convert(tc.sql)
236+
if tc.expectedErr {
237+
assert.NotNil(t, err)
238+
} else {
239+
assert.Nil(t, err)
240+
}
241+
assert.Len(t, ops, len(tc.expectedOps))
242+
assert.Equal(t, tc.expectedOps, ops)
243+
})
244+
}
245+
}

0 commit comments

Comments
 (0)