Skip to content

Commit 9922e5f

Browse files
authored
Validate that FK constraint name is unique (#428)
We can do this in our validation step before even touching the database. Part of #105
1 parent d00dd1e commit 9922e5f

File tree

3 files changed

+101
-6
lines changed

3 files changed

+101
-6
lines changed

pkg/migrations/op_set_fk.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"strings"
99

1010
"github.com/lib/pq"
11+
1112
"github.com/xataio/pgroll/pkg/db"
1213
"github.com/xataio/pgroll/pkg/schema"
1314
)
@@ -58,6 +59,18 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error
5859
}
5960
}
6061

62+
table := s.GetTable(o.Table)
63+
if table == nil {
64+
return TableDoesNotExistError{Name: o.Table}
65+
}
66+
67+
if table.ConstraintExists(o.References.Name) {
68+
return ConstraintAlreadyExistsError{
69+
Table: table.Name,
70+
Constraint: o.References.Name,
71+
}
72+
}
73+
6174
if o.Up == "" {
6275
return FieldRequiredError{Name: "up"}
6376
}

pkg/migrations/op_set_fk_test.go

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import (
66
"database/sql"
77
"testing"
88

9-
"github.com/xataio/pgroll/internal/testutils"
10-
119
"github.com/stretchr/testify/assert"
10+
11+
"github.com/xataio/pgroll/internal/testutils"
1212
"github.com/xataio/pgroll/pkg/migrations"
1313
)
1414

@@ -1120,6 +1120,88 @@ func TestSetForeignKey(t *testing.T) {
11201120
ColumnMustHaveComment(t, db, schema, "posts", "user_id", "the id of the author")
11211121
},
11221122
},
1123+
{
1124+
name: "validate that foreign key name is unique",
1125+
migrations: []migrations.Migration{
1126+
{
1127+
Name: "01_add_tables",
1128+
Operations: migrations.Operations{
1129+
&migrations.OpCreateTable{
1130+
Name: "users",
1131+
Columns: []migrations.Column{
1132+
{
1133+
Name: "id",
1134+
Type: "serial",
1135+
Pk: ptr(true),
1136+
},
1137+
{
1138+
Name: "name",
1139+
Type: "text",
1140+
},
1141+
},
1142+
},
1143+
&migrations.OpCreateTable{
1144+
Name: "posts",
1145+
Columns: []migrations.Column{
1146+
{
1147+
Name: "id",
1148+
Type: "serial",
1149+
Pk: ptr(true),
1150+
},
1151+
{
1152+
Name: "title",
1153+
Type: "text",
1154+
},
1155+
{
1156+
Name: "user_id",
1157+
Type: "integer",
1158+
Nullable: ptr(true),
1159+
},
1160+
},
1161+
},
1162+
},
1163+
},
1164+
{
1165+
Name: "02_add_fk_constraint",
1166+
Operations: migrations.Operations{
1167+
&migrations.OpAlterColumn{
1168+
Table: "posts",
1169+
Column: "user_id",
1170+
References: &migrations.ForeignKeyReference{
1171+
Name: "fk_users_id",
1172+
Table: "users",
1173+
Column: "id",
1174+
},
1175+
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
1176+
Down: "user_id",
1177+
},
1178+
},
1179+
},
1180+
{
1181+
Name: "03_add_fk_constraint_again",
1182+
Operations: migrations.Operations{
1183+
&migrations.OpAlterColumn{
1184+
Table: "posts",
1185+
Column: "user_id",
1186+
References: &migrations.ForeignKeyReference{
1187+
Name: "fk_users_id",
1188+
Table: "users",
1189+
Column: "id",
1190+
},
1191+
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
1192+
Down: "user_id",
1193+
},
1194+
},
1195+
},
1196+
},
1197+
wantStartErr: migrations.ConstraintAlreadyExistsError{
1198+
Table: "posts",
1199+
Constraint: "fk_users_id",
1200+
},
1201+
afterStart: func(t *testing.T, db *sql.DB, schema string) {},
1202+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {},
1203+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {},
1204+
},
11231205
})
11241206
}
11251207

pkg/roll/execute.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func (m *Roll) Complete(ctx context.Context) error {
155155
}
156156

157157
// read the current schema
158-
schema, err := m.state.ReadSchema(ctx, m.schema)
158+
currentSchema, err := m.state.ReadSchema(ctx, m.schema)
159159
if err != nil {
160160
return fmt.Errorf("unable to read schema: %w", err)
161161
}
@@ -175,7 +175,7 @@ func (m *Roll) Complete(ctx context.Context) error {
175175
// execute operations
176176
refreshViews := false
177177
for _, op := range migration.Operations {
178-
err := op.Complete(ctx, m.pgConn, m.sqlTransformer, schema)
178+
err := op.Complete(ctx, m.pgConn, m.sqlTransformer, currentSchema)
179179
if err != nil {
180180
return fmt.Errorf("unable to execute complete operation: %w", err)
181181
}
@@ -189,12 +189,12 @@ func (m *Roll) Complete(ctx context.Context) error {
189189

190190
// recreate views for the new version (if some operations require it, ie SQL)
191191
if refreshViews && !m.disableVersionSchemas {
192-
schema, err = m.state.ReadSchema(ctx, m.schema)
192+
currentSchema, err = m.state.ReadSchema(ctx, m.schema)
193193
if err != nil {
194194
return fmt.Errorf("unable to read schema: %w", err)
195195
}
196196

197-
err = m.ensureViews(ctx, schema, migration.Name)
197+
err = m.ensureViews(ctx, currentSchema, migration.Name)
198198
if err != nil {
199199
return err
200200
}

0 commit comments

Comments
 (0)