Skip to content

Commit 51b2641

Browse files
authored
Convert DROP TABLE SQL to pgroll operation (#529)
Converts `DROP TABLE` statements in these forms: ```sql DROP TABLE foo DROP TABLE foo RESTRICT DROP TABLE foo.bar DROP TABLE IF EXISTS foo ``` These forms fall back to raw SQL: ```sql DROP TABLE foo CASCADE ``` Part of #504
1 parent ef07810 commit 51b2641

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

pkg/sql2pgroll/drop.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,12 @@ import (
1212

1313
// convertDropStatement converts supported drop statements to pgroll operations
1414
func convertDropStatement(stmt *pgq.DropStmt) (migrations.Operations, error) {
15-
if stmt.RemoveType == pgq.ObjectType_OBJECT_INDEX {
15+
switch stmt.RemoveType {
16+
case pgq.ObjectType_OBJECT_INDEX:
1617
return convertDropIndexStatement(stmt)
18+
case pgq.ObjectType_OBJECT_TABLE:
19+
return convertDropTableStatement(stmt)
20+
1721
}
1822
return nil, nil
1923
}
@@ -46,3 +50,27 @@ func canConvertDropIndex(stmt *pgq.DropStmt) bool {
4650
}
4751
return true
4852
}
53+
54+
// convertDropTableStatement converts simple DROP TABLE statements to pgroll operations
55+
func convertDropTableStatement(stmt *pgq.DropStmt) (migrations.Operations, error) {
56+
if !canConvertDropTable(stmt) {
57+
return nil, nil
58+
}
59+
60+
items := stmt.GetObjects()[0].GetList().GetItems()
61+
parts := make([]string, len(items))
62+
for i, item := range items {
63+
parts[i] = item.GetString_().GetSval()
64+
}
65+
66+
return migrations.Operations{
67+
&migrations.OpDropTable{
68+
Name: strings.Join(parts, "."),
69+
},
70+
}, nil
71+
}
72+
73+
// canConvertDropTable checks whether we can convert the statement without losing any information.
74+
func canConvertDropTable(stmt *pgq.DropStmt) bool {
75+
return stmt.Behavior != pgq.DropBehavior_DROP_CASCADE
76+
}

pkg/sql2pgroll/drop_test.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,52 @@ func TestDropIndexStatements(t *testing.T) {
5454
}
5555
}
5656

57-
func TestUnconvertableDropIndexStatements(t *testing.T) {
57+
func TestDropTableStatements(t *testing.T) {
58+
t.Parallel()
59+
60+
tests := []struct {
61+
sql string
62+
expectedOp migrations.Operation
63+
}{
64+
{
65+
sql: "DROP TABLE foo",
66+
expectedOp: expect.DropTableOp1,
67+
},
68+
{
69+
sql: "DROP TABLE foo RESTRICT",
70+
expectedOp: expect.DropTableOp1,
71+
},
72+
{
73+
sql: "DROP TABLE IF EXISTS foo",
74+
expectedOp: expect.DropTableOp1,
75+
},
76+
{
77+
sql: "DROP TABLE foo.bar",
78+
expectedOp: expect.DropTableOp2,
79+
},
80+
}
81+
82+
for _, tc := range tests {
83+
t.Run(tc.sql, func(t *testing.T) {
84+
ops, err := sql2pgroll.Convert(tc.sql)
85+
require.NoError(t, err)
86+
87+
require.Len(t, ops, 1)
88+
89+
assert.Equal(t, tc.expectedOp, ops[0])
90+
})
91+
}
92+
}
93+
94+
func TestUnconvertableDropStatements(t *testing.T) {
5895
t.Parallel()
5996

6097
tests := []string{
98+
// Drop index
6199
"DROP INDEX foo CASCADE",
100+
101+
// Drop table
102+
"DROP TABLE foo CASCADE",
62103
}
63104

64105
for _, sql := range tests {

pkg/sql2pgroll/expect/drop_table.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package expect
4+
5+
import (
6+
"github.com/xataio/pgroll/pkg/migrations"
7+
)
8+
9+
var DropTableOp1 = &migrations.OpDropTable{
10+
Name: "foo",
11+
}
12+
13+
var DropTableOp2 = &migrations.OpDropTable{
14+
Name: "foo.bar",
15+
}

0 commit comments

Comments
 (0)