Skip to content

Commit 2051e2f

Browse files
Add exclude schemas
1 parent 694c00c commit 2051e2f

File tree

11 files changed

+312
-63
lines changed

11 files changed

+312
-63
lines changed

internal/migration_acceptance_tests/acceptance_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe
130130
return diff.Generate(ctx, connPool, diff.DDLSchemaSource(newSchemaDDL),
131131
append(planOpts,
132132
diff.WithBetaDoNotCallWithAllowCustomSchemaOpts(),
133-
diff.WithSchemas("public"),
134133
diff.WithTempDbFactory(tempDbFactory),
135134
)...)
136135
}

internal/schema/filters.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ func schemaNameFilter(schema string) nameFilter {
1010
}
1111
}
1212

13+
func notSchemaNameFilter(schema string) nameFilter {
14+
return func(obj SchemaQualifiedName) bool {
15+
return obj.SchemaName != schema
16+
}
17+
}
18+
1319
func orNameFilter(filters ...nameFilter) nameFilter {
1420
return func(obj SchemaQualifiedName) bool {
1521
for _, filter := range filters {
@@ -21,6 +27,21 @@ func orNameFilter(filters ...nameFilter) nameFilter {
2127
}
2228
}
2329

30+
func andNameFilter(filters ...nameFilter) nameFilter {
31+
return func(obj SchemaQualifiedName) bool {
32+
if len(filters) == 0 {
33+
return false
34+
}
35+
36+
for _, filter := range filters {
37+
if !filter(obj) {
38+
return false
39+
}
40+
}
41+
return true
42+
}
43+
}
44+
2445
func filterSliceByName[T any](objs []T, getNameFn func(T) SchemaQualifiedName, filter nameFilter) []T {
2546
var filteredObjs []T
2647
for _, obj := range objs {

internal/schema/filters_test.go

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,70 @@ func (f fakeNameFilter) filter(input SchemaQualifiedName) bool {
3030
}
3131

3232
func TestOrNameFilters(t *testing.T) {
33+
someName1 := SchemaQualifiedName{
34+
SchemaName: "some_schema",
35+
EscapedName: "some_name",
36+
}
37+
for _, tc := range []struct {
38+
name string
39+
input SchemaQualifiedName
40+
filters []fakeNameFilterMock
41+
expectedOut bool
42+
}{
43+
{
44+
name: "empty",
45+
input: someName1,
46+
expectedOut: false,
47+
},
48+
{
49+
name: "one filter (true)",
50+
input: someName1,
51+
filters: []fakeNameFilterMock{
52+
{
53+
expectedInput: someName1,
54+
returnValue: true,
55+
},
56+
},
57+
expectedOut: true,
58+
},
59+
{
60+
name: "one filter (false)",
61+
input: someName1,
62+
filters: []fakeNameFilterMock{
63+
{expectedInput: someName1, returnValue: false},
64+
},
65+
expectedOut: false,
66+
},
67+
{
68+
name: "two filters (false, true)",
69+
input: someName1,
70+
filters: []fakeNameFilterMock{
71+
{expectedInput: someName1, returnValue: false},
72+
{expectedInput: someName1, returnValue: true},
73+
},
74+
expectedOut: false,
75+
},
76+
{
77+
name: "two filters (true, true)",
78+
input: someName1,
79+
filters: []fakeNameFilterMock{
80+
{expectedInput: someName1, returnValue: true},
81+
{expectedInput: someName1, returnValue: true},
82+
},
83+
expectedOut: true,
84+
},
85+
} {
86+
t.Run(tc.name, func(t *testing.T) {
87+
var filters []nameFilter
88+
for _, filter := range tc.filters {
89+
filters = append(filters, newFakeNameFilter(t, filter).filter)
90+
}
91+
assert.Equal(t, tc.expectedOut, andNameFilter(filters...)(tc.input))
92+
})
93+
}
94+
}
95+
96+
func TestAndNameFilters(t *testing.T) {
3397
someName1 := SchemaQualifiedName{
3498
SchemaName: "some_schema",
3599
EscapedName: "some_name",
@@ -91,5 +155,4 @@ func TestOrNameFilters(t *testing.T) {
91155
assert.Equal(t, tc.expectedOut, orNameFilter(filters...)(tc.input))
92156
})
93157
}
94-
95158
}

internal/schema/schema.go

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func (t Table) IsPartitioned() bool {
155155

156156
// IsPartition returns whether the table is a partition.
157157
// It represents a mismatch in modeling because the ForValues and ParentTable are stored separately.
158-
// Instead, they fields be stored under the same struct as a nilable pointer and this function should be deleted
158+
// Instead, the fields should be stored under the same struct as a nilable pointer, and this function should be deleted.
159159
func (t Table) IsPartition() bool {
160160
return t.ParentTable != nil
161161
}
@@ -346,21 +346,31 @@ type (
346346
GetSchemaOpt func(*getSchemaOptions)
347347
)
348348

349-
// WithSchemas filters the schema to only include the given schemas. This unions with any schemas that are already included
350-
// via WithSchemas.
351-
func WithSchemas(schemas ...string) GetSchemaOpt {
349+
// WithIncludeSchemas filters the schema to only include the given schemas. This unions with any schemas that are already included
350+
// via WithIncludeSchemas. If empty, then all schemas are included.
351+
func WithIncludeSchemas(schemas ...string) GetSchemaOpt {
352352
return func(o *getSchemaOptions) {
353353
for _, schema := range schemas {
354354
o.includeSchemas = append(o.includeSchemas, schema)
355355
}
356356
}
357357
}
358358

359+
// WithExcludeSchemas filters the schema to exclude the given schemas. This unions with any schemas that are already excluded
360+
// via WithExcludeSchemas. If empty, then no schemas are excluded.
361+
func WithExcludeSchemas(schemas ...string) GetSchemaOpt {
362+
return func(o *getSchemaOptions) {
363+
o.excludeSchemas = append(o.excludeSchemas, schemas...)
364+
}
365+
}
366+
359367
type getSchemaOptions struct {
360368
// includeSchemas is a list of schemas to include in the schema. If empty, then all schemas are included.
361369
// We could have built a more complex set of options using the nameFilter system (nested unions and intersections);
362370
// however, I felt it could expose some weird behaviors that we don't want to have to worry about just yet,
363371
includeSchemas []string
372+
// excludeSchemas is the exclude analog of includeSchemas.
373+
excludeSchemas []string
364374
}
365375

366376
// GetSchema fetches the database schema. It is a non-atomic operation.
@@ -382,34 +392,77 @@ func GetSchema(ctx context.Context, db queries.DBTX, opts ...GetSchemaOpt) (Sche
382392
opt(&options)
383393
}
384394

395+
nameFilter, err := buildNameFilter(options)
396+
if err != nil {
397+
return Schema{}, fmt.Errorf("building name filter: %w", err)
398+
}
399+
385400
return (&schemaFetcher{
386401
q: queries.New(db),
387402
goroutineRunnerFactory: goroutineRunnerFactory,
388-
nameFilter: buildNameFilter(options),
403+
nameFilter: nameFilter,
389404
}).getSchema(ctx)
390405
}
391406

392-
func buildNameFilter(options getSchemaOptions) nameFilter {
393-
if len(options.includeSchemas) == 0 {
407+
func buildNameFilter(options getSchemaOptions) (nameFilter, error) {
408+
if intersection := intersect(options.includeSchemas, options.excludeSchemas); len(intersection) > 0 {
409+
return nil, fmt.Errorf("schemas %v are both included and excluded", intersection)
410+
}
411+
412+
includeSchemasFilter := buildIncludeSchemasFilter(options.includeSchemas)
413+
excludeSchemasFilter := buildExcludeSchemasFilter(options.excludeSchemas)
414+
return andNameFilter(includeSchemasFilter, excludeSchemasFilter), nil
415+
}
416+
417+
func intersect(a, b []string) []string {
418+
inAByA := make(map[string]bool)
419+
for _, s := range a {
420+
inAByA[s] = true
421+
}
422+
intersection := make([]string, 0, len(b))
423+
for _, s := range b {
424+
if inAByA[s] {
425+
intersection = append(intersection, s)
426+
}
427+
}
428+
return intersection
429+
}
430+
431+
func buildIncludeSchemasFilter(schemas []string) nameFilter {
432+
if len(schemas) == 0 {
394433
return func(name SchemaQualifiedName) bool {
395434
return true
396435
}
397436
}
398437

399438
var filters []nameFilter
400-
for _, schema := range options.includeSchemas {
439+
for _, schema := range schemas {
401440
filters = append(filters, schemaNameFilter(schema))
402441
}
403442
return orNameFilter(filters...)
404443
}
405444

445+
func buildExcludeSchemasFilter(schemas []string) nameFilter {
446+
if len(schemas) == 0 {
447+
return func(name SchemaQualifiedName) bool {
448+
return true
449+
}
450+
}
451+
452+
var filters []nameFilter
453+
for _, schema := range schemas {
454+
filters = append(filters, notSchemaNameFilter(schema))
455+
}
456+
return andNameFilter(filters...)
457+
}
458+
406459
type (
407460
schemaFetcher struct {
408461
q *queries.Queries
409462
// goroutineRunnerFactory is a factory function that returns a GoroutineRunner. We need to be able to construct
410463
// multiple GoroutineRunners to avoid deadlock created by circular dependencies of submitted go routines.
411464
goroutineRunnerFactory func() concurrent.GoroutineRunner
412-
// nameFilter is a filter that determienes which schema objects to include in the schema via their
465+
// nameFilter is a filter that determines which schema objects to include in the schema via their
413466
// schema name and object name.
414467
//
415468
// Currently, we don't do any sort of validation to ensure that all dependencies are included, so users might

internal/schema/schema_test.go

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ import (
1616

1717
type testCase struct {
1818
name string
19-
ddl []string
2019
opts []GetSchemaOpt
20+
ddl []string
2121
expectedSchema Schema
2222
// expectedHash is the expected hash of the schema. If it is not provided, the test will not validate the hash.
23-
expectedHash string
24-
expectedErrIs error
23+
expectedHash string
24+
expectedErrIs error
25+
expectedErrContains string
2526
}
2627

2728
var (
@@ -38,7 +39,7 @@ var (
3839
{
3940
name: "Simple schema (validate all schema objects and schema name filters)",
4041
opts: []GetSchemaOpt{
41-
WithSchemas("public", "schema_1", "schema_2"),
42+
WithIncludeSchemas("public", "schema_1", "schema_2"),
4243
},
4344
ddl: []string{`
4445
CREATE SCHEMA schema_1;
@@ -796,7 +797,7 @@ var (
796797
},
797798
{
798799
name: "Filtering - filtering out the base table",
799-
opts: []GetSchemaOpt{WithSchemas("public")},
800+
opts: []GetSchemaOpt{WithIncludeSchemas("public")},
800801
ddl: []string{`
801802
CREATE SCHEMA schema_filtered_1;
802803
CREATE TABLE schema_filtered_1.foobar(
@@ -821,7 +822,7 @@ var (
821822
},
822823
{
823824
name: "Filtering - filtering out partition",
824-
opts: []GetSchemaOpt{WithSchemas("public")},
825+
opts: []GetSchemaOpt{WithIncludeSchemas("public")},
825826
ddl: []string{`
826827
CREATE SCHEMA schema_filtered_1;
827828
CREATE TABLE foobar(
@@ -901,6 +902,69 @@ var (
901902
},
902903
},
903904
},
905+
{
906+
name: "Filters - exclude schemas",
907+
opts: []GetSchemaOpt{
908+
WithExcludeSchemas("schema_1"),
909+
},
910+
ddl: []string{`
911+
CREATE TABLE foobar();
912+
CREATE SCHEMA schema_1;
913+
CREATE TABLE schema_1.foobar();
914+
CREATE SCHEMA schema_2;
915+
CREATE TABLE schema_2.foobar();
916+
`},
917+
expectedSchema: Schema{
918+
Tables: []Table{
919+
{
920+
SchemaQualifiedName: SchemaQualifiedName{SchemaName: "public", EscapedName: "\"foobar\""},
921+
ReplicaIdentity: ReplicaIdentityDefault,
922+
},
923+
{
924+
SchemaQualifiedName: SchemaQualifiedName{SchemaName: "schema_2", EscapedName: "\"foobar\""},
925+
ReplicaIdentity: ReplicaIdentityDefault,
926+
},
927+
},
928+
},
929+
},
930+
{
931+
name: "Filters - include and exclude schemas",
932+
opts: []GetSchemaOpt{
933+
WithIncludeSchemas("schema_1"),
934+
// schema_3 is inherently excluded since it is not included
935+
WithExcludeSchemas("schema_2"),
936+
},
937+
ddl: []string{`
938+
CREATE TABLE foobar();
939+
CREATE SCHEMA schema_1;
940+
CREATE TABLE schema_1.foobar();
941+
CREATE SCHEMA schema_2;
942+
CREATE TABLE schema_2.foobar();
943+
`},
944+
expectedSchema: Schema{
945+
Tables: []Table{
946+
{
947+
SchemaQualifiedName: SchemaQualifiedName{SchemaName: "schema_1", EscapedName: "\"foobar\""},
948+
ReplicaIdentity: ReplicaIdentityDefault,
949+
},
950+
},
951+
},
952+
},
953+
{
954+
name: "Filter - include and exclude a schema",
955+
opts: []GetSchemaOpt{
956+
WithIncludeSchemas("schema_1"),
957+
958+
WithIncludeSchemas("schema_2"),
959+
WithExcludeSchemas("schema_2"),
960+
961+
WithExcludeSchemas("schema_3"),
962+
963+
WithExcludeSchemas("schema_4"),
964+
WithExcludeSchemas("schema_4"),
965+
},
966+
expectedErrContains: "are both included and excluded",
967+
},
904968
}
905969
)
906970

@@ -948,7 +1012,12 @@ func runTestCase(t *testing.T, engine *pgengine.Engine, testCase *testCase, getD
9481012
if testCase.expectedErrIs != nil {
9491013
require.ErrorIs(t, err, testCase.expectedErrIs)
9501014
return
951-
} else {
1015+
}
1016+
if testCase.expectedErrContains != "" {
1017+
require.ErrorContains(t, err, testCase.expectedErrContains)
1018+
return
1019+
}
1020+
if testCase.expectedErrIs == nil && testCase.expectedErrContains == "" {
9521021
require.NoError(t, err)
9531022
}
9541023

0 commit comments

Comments
 (0)