Skip to content

Commit 2ff0799

Browse files
authored
Support PostgreSQL double qouted columns (#590)
* zhars/support_postgres_double_qouted_columns Improve SQLParser/(added support of double-quoted column names for PostgreSQL)/fixed searchable encryption mapping with double-quoted tables.
1 parent fc7b491 commit 2ff0799

File tree

8 files changed

+1996
-1802
lines changed

8 files changed

+1996
-1802
lines changed

encryptor/searchable_query_filter.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.T
9494
// And even then, we can work only with tables that we have an encryption schema for.
9595
var encryptableTables []*AliasedTableName
9696
for _, table := range tables {
97-
if v := filter.schemaStore.GetTableSchema(table.TableName.Name.String()); v != nil {
97+
if v := filter.schemaStore.GetTableSchema(table.TableName.Name.ValueForConfig()); v != nil {
9898
encryptableTables = append(encryptableTables, table)
9999
}
100100
}
@@ -174,9 +174,9 @@ func (filter *SearchableQueryFilter) filterComparisons(exprs []*sqlparser.Compar
174174

175175
func (filter *SearchableQueryFilter) getTableSchemaOfColumn(column *sqlparser.ColName, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) config.TableSchema {
176176
if column.Qualifier.Qualifier.IsEmpty() {
177-
return filter.schemaStore.GetTableSchema(defaultTable.TableName.Name.String())
177+
return filter.schemaStore.GetTableSchema(defaultTable.TableName.Name.ValueForConfig())
178178
}
179-
tableName := aliasedTables[column.Qualifier.Name.String()]
179+
tableName := aliasedTables[column.Qualifier.Name.ValueForConfig()]
180180
return filter.schemaStore.GetTableSchema(tableName)
181181
}
182182

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package encryptor
2+
3+
import (
4+
"testing"
5+
6+
"github.com/cossacklabs/acra/encryptor/config"
7+
"github.com/cossacklabs/acra/sqlparser"
8+
)
9+
10+
func TestGetTableSchemaOfColumnMatchConfigTable(t *testing.T) {
11+
tableNameUpperCase := "SomeTableInUpperCase"
12+
configStr := `
13+
schemas:
14+
- table: sometableinuppercase
15+
encrypted:
16+
- column: "default_client_id"
17+
- column: specified_client_id
18+
client_id: specified_client_id
19+
`
20+
schemaStore, err := config.MapTableSchemaStoreFromConfig([]byte(configStr))
21+
if err != nil {
22+
t.Fatalf("Can't parse config: %s", err.Error())
23+
}
24+
25+
searchableQueryFilter := SearchableQueryFilter{
26+
schemaStore: schemaStore,
27+
}
28+
29+
tableNamesWithQuotes := sqlparser.NewTableIdentWithQuotes(tableNameUpperCase, '"')
30+
schemaTable := searchableQueryFilter.getTableSchemaOfColumn(&sqlparser.ColName{}, &AliasedTableName{
31+
TableName: sqlparser.TableName{
32+
Name: tableNamesWithQuotes,
33+
},
34+
}, AliasToTableMap{})
35+
36+
if schemaTable == nil {
37+
t.Fatalf("Expect not nil schemaTable, matched with config")
38+
}
39+
}
40+
41+
func TestFilterInterestingTables(t *testing.T) {
42+
tableNameUpperCase := "SomeTableInUpperCase"
43+
configStr := `
44+
schemas:
45+
- table: sometableinuppercase
46+
encrypted:
47+
- column: "default_client_id"
48+
- column: specified_client_id
49+
client_id: specified_client_id
50+
`
51+
schemaStore, err := config.MapTableSchemaStoreFromConfig([]byte(configStr))
52+
if err != nil {
53+
t.Fatalf("Can't parse config: %s", err.Error())
54+
}
55+
56+
searchableQueryFilter := SearchableQueryFilter{
57+
schemaStore: schemaStore,
58+
}
59+
60+
tableNamesWithQuotes := sqlparser.NewTableIdentWithQuotes(tableNameUpperCase, '"')
61+
62+
aliasedTable, _ := searchableQueryFilter.filterInterestingTables(sqlparser.TableExprs{
63+
&sqlparser.AliasedTableExpr{
64+
Expr: sqlparser.TableName{
65+
Name: tableNamesWithQuotes,
66+
},
67+
},
68+
})
69+
70+
if aliasedTable == nil {
71+
t.Fatalf("Expect not nil aliasedTable, matched with config")
72+
}
73+
}

encryptor/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ func mapColumnsToAliases(selectQuery *sqlparser.Select) ([]*columnInfo, error) {
359359
if ok {
360360
if len(joinTables) > 0 {
361361
if !starExpr.TableName.Name.IsEmpty() {
362-
joinTable, ok := joinAliases[starExpr.TableName.Name.String()]
362+
joinTable, ok := joinAliases[starExpr.TableName.Name.ValueForConfig()]
363363
if !ok {
364364
return nil, errUnsupportedExpression
365365
}

sqlparser/ast_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/cossacklabs/acra/sqlparser/dialect/mysql"
2424
"github.com/cossacklabs/acra/sqlparser/dialect/postgresql"
2525
"reflect"
26+
"strconv"
2627
"strings"
2728
"testing"
2829
"unsafe"
@@ -535,6 +536,31 @@ func TestTableIdentMarshal(t *testing.T) {
535536
}
536537
}
537538

539+
func TestTableIdentValueForConfig(t *testing.T) {
540+
str := TableIdent{
541+
quote: 34,
542+
v: "table",
543+
}
544+
got := String(str)
545+
want := `"table"`
546+
if got != want {
547+
t.Errorf("json.Marshal()= %s, want %s", got, want)
548+
}
549+
tableForConfig := str.ValueForConfig()
550+
if tableForConfig == got {
551+
t.Errorf("ValueForConfig should not be equal with init %s, want %s", got, want)
552+
}
553+
554+
unquoted, err := strconv.Unquote(got)
555+
if err != nil {
556+
t.Fatal(err)
557+
}
558+
559+
if tableForConfig != unquoted {
560+
t.Errorf("ValueForConfig should be equal with unquoted value %s, want %s", got, want)
561+
}
562+
}
563+
538564
func TestHexDecode(t *testing.T) {
539565
testcase := []struct {
540566
in, out string

sqlparser/parse_test.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,33 @@ var (
211211
output: "select /* string table alias */ 1 from t as 't1'",
212212
// mysql allow to use single quote for column/table aliases
213213
dialect: mysql.NewMySQLDialect(),
214+
}, {
215+
input: `select * from mytable where "AGE" = 1 and "TEST" = 'test'`,
216+
// postgres allow to use double quote string for columns
217+
dialect: postgresql.NewPostgreSQLDialect(),
218+
}, {
219+
// this is valid query ONLY for MySQL in default mode, for now,
220+
// but invalid for PostgreSQL and MySQL in ANSI mode and maybe be changed in future
221+
input: `insert into some_table(id, data) VALUES (10918, "test")`,
222+
output: `insert into some_table(id, data) values (10918, 'test')`,
223+
}, {
224+
input: `insert into some_table(id, data) values (10918, 'test')`,
225+
}, {
226+
input: `select * from mytable where "test" = "test"`,
227+
output: `select * from mytable where 'test' = 'test'`,
228+
}, {
229+
input: `select * from mytable where "test" = 1 and 'value' = 'value'`,
230+
output: `select * from mytable where 'test' = 1 and 'value' = 'value'`,
231+
}, {
232+
input: `SELECT "id", "landline_number" AS "landlineNumber", "removal" FROM "users" AS "User" where "User"."is_active"`,
233+
output: `select "id", "landline_number" as "landlineNumber", "removal" from "users" as "User" where "User"."is_active"`,
234+
dialect: postgresql.NewPostgreSQLDialect(),
235+
}, {
236+
input: `select "id" from "users" as "User" where "User"."AGE" = 123`,
237+
dialect: postgresql.NewPostgreSQLDialect(),
238+
}, {
239+
input: `select "id" from "users" as "User" where "AGE" = '123'`,
240+
dialect: postgresql.NewPostgreSQLDialect(),
214241
}, {
215242
input: "select /* string table alias without as */ 1 from t 't1'",
216243
output: "select /* string table alias without as */ 1 from t as 't1'",
@@ -1322,11 +1349,17 @@ var (
13221349
)
13231350

13241351
func TestValid(t *testing.T) {
1352+
var testDialect dialect.Dialect
13251353
for i, tcase := range validSQL {
13261354
if tcase.output == "" {
13271355
tcase.output = tcase.input
13281356
}
1329-
tree, err := New(ModeStrict).Parse(tcase.input)
1357+
1358+
testDialect = tcase.dialect
1359+
if tcase.dialect == nil {
1360+
testDialect = mysql.NewMySQLDialect()
1361+
}
1362+
tree, err := ParseWithDialect(testDialect, tcase.input)
13301363
if err != nil {
13311364
t.Errorf("Parse(%q) err: %v, want nil", tcase.input, err)
13321365
continue
@@ -1656,6 +1689,9 @@ func TestConvert(t *testing.T) {
16561689
input: "select convert('abc', decimal(4+9)) from t",
16571690
output: "syntax error at position 33",
16581691
},
1692+
// TODO: added test cases to cover errors for MySQL ANSI mode
1693+
// `insert into table (id, name) values (125, "data")` currently, in ANSI mod its valid query with contains
1694+
// Rows {SQLVal(125), ColName("data")} - but it should fail with error
16591695
}
16601696

16611697
var dialect dialect.Dialect

0 commit comments

Comments
 (0)