@@ -185,19 +185,20 @@ func TrimStringToN(query string, n int) string {
185185 return query [:n ]
186186}
187187
188- // NormalizeAndRedactSQLQuery returns a normalized (lowercases SQL commands) SQL string,
188+ // HandleRawSQLQuery returns a normalized (lowercases SQL commands) SQL string,
189189// and redacted SQL string with the params stripped out for display.
190190// Taken from sqlparser package
191- func NormalizeAndRedactSQLQuery (sql string ) (normalizedQuery string , redactedQuery string , error error ) {
191+ func HandleRawSQLQuery (sql string ) (normalizedQuery , redactedQuery string , parsedQuery sqlparser. Statement , err error ) {
192192 bv := map [string ]* querypb.BindVariable {}
193193 sqlStripped , _ := sqlparser .SplitMarginComments (sql )
194194
195195 // sometimes queries might have ; at the end, that should be stripped
196196 sqlStripped = strings .TrimSuffix (sqlStripped , ";" )
197197
198198 stmt , err := sqlparser .Parse (sqlStripped )
199+ outputStmt , _ := sqlparser .Parse (sqlStripped )
199200 if err != nil {
200- return "" , "" , err
201+ return "" , "" , nil , ErrQuerySyntaxError
201202 }
202203
203204 normalizedQ := sqlparser .String (stmt )
@@ -206,5 +207,114 @@ func NormalizeAndRedactSQLQuery(sql string) (normalizedQuery string, redactedQue
206207 sqlparser .Normalize (stmt , bv , ValueMask )
207208 redactedQ := sqlparser .String (stmt )
208209
209- return normalizedQ , redactedQ , nil
210+ return normalizedQ , redactedQ , outputStmt , nil
211+ }
212+
213+ // CheckPatternsMatching evaluates if parsed query matches specified set of patterns
214+ func CheckPatternsMatching (patterns []sqlparser.Statement , parsedQuery sqlparser.Statement ) bool {
215+ for _ , pattern := range patterns {
216+ if checkSinglePatternMatch (parsedQuery , pattern ) {
217+ return true
218+ }
219+ }
220+ return false
221+ }
222+
223+ // CheckExactQueriesMatch evaluates if query presents in set of queries
224+ func CheckExactQueriesMatch (normalizedQuery string , setOfQueries map [string ]bool ) bool {
225+ if ! setOfQueries [normalizedQuery ] {
226+ return false
227+ }
228+ return true
229+ }
230+
231+ // CheckTableNamesMatch evaluates if query contains table presented in specified set of tables
232+ func CheckTableNamesMatch (parsedQuery sqlparser.Statement , setOfTables map [string ]bool ) (bool , bool ) {
233+ atLeastOneTableNameMatch := false
234+ allTableNamesMatch := false
235+
236+ switch parsedQuery .(type ) {
237+ case * sqlparser.Select :
238+ selectQuery := parsedQuery .(* sqlparser.Select )
239+ atLeastOneTableNameMatch , allTableNamesMatch = checkTableExprsMatch (selectQuery .From , setOfTables )
240+ break
241+ case * sqlparser.Insert :
242+ insertQuery := parsedQuery .(* sqlparser.Insert )
243+ if setOfTables [insertQuery .Table .Name .String ()] {
244+ atLeastOneTableNameMatch = true
245+ allTableNamesMatch = true
246+ } else {
247+ atLeastOneTableNameMatch = false
248+ allTableNamesMatch = false
249+ }
250+ break
251+ default :
252+ //TODO other query types
253+ return false , false
254+ }
255+
256+ return atLeastOneTableNameMatch , allTableNamesMatch
257+ }
258+
259+ // Tables matchers
260+ func checkTableExprsMatch (tables sqlparser.TableExprs , setOfTables map [string ]bool ) (bool , bool ) {
261+ oneTableMatch := false
262+ allTablesMatch := false
263+ counter := 0
264+ for _ , tableExpr := range tables {
265+ oneTableMatchInternal , allTablesMatchInternal := checkTableExprMatch (tableExpr , setOfTables )
266+ if oneTableMatchInternal {
267+ oneTableMatch = true
268+ if allTablesMatchInternal {
269+ counter ++
270+ }
271+ }
272+ }
273+ if counter == len (tables ) {
274+ allTablesMatch = true
275+ }
276+ return oneTableMatch , allTablesMatch
277+ }
278+
279+ func checkTableExprMatch (table sqlparser.TableExpr , setOfTables map [string ]bool ) (bool , bool ) {
280+ oneTableMatch := false
281+ allTablesMatch := false
282+
283+ switch table .(type ) {
284+ case * sqlparser.AliasedTableExpr :
285+ oneTableMatch , allTablesMatch = checkAliasedTable (table .(* sqlparser.AliasedTableExpr ), setOfTables )
286+ case * sqlparser.JoinTableExpr :
287+ oneTableMatch , allTablesMatch = checkJoinedTable (table .(* sqlparser.JoinTableExpr ), setOfTables )
288+ case * sqlparser.ParenTableExpr :
289+ oneTableMatch , allTablesMatch = checkParenTable (table .(* sqlparser.ParenTableExpr ), setOfTables )
290+ }
291+ return oneTableMatch , allTablesMatch
292+ }
293+
294+ func checkAliasedTable (table * sqlparser.AliasedTableExpr , setOfTables map [string ]bool ) (bool , bool ) {
295+ if setOfTables [sqlparser .String (table .Expr )] {
296+ return true , true
297+ }
298+ return false , false
299+ }
300+
301+ func checkJoinedTable (table * sqlparser.JoinTableExpr , setOfTables map [string ]bool ) (bool , bool ) {
302+ oneTableMatch := false
303+ allTablesMatch := false
304+
305+ oneLeftTableMatchInternal , allLeftTablesMatchInternal := checkTableExprMatch (table .LeftExpr , setOfTables )
306+ oneRightTableMatchInternal , allRightTablesMatchInternal := checkTableExprMatch (table .RightExpr , setOfTables )
307+
308+ if oneLeftTableMatchInternal || oneRightTableMatchInternal {
309+ oneTableMatch = true
310+ }
311+ if allLeftTablesMatchInternal && allRightTablesMatchInternal {
312+ allTablesMatch = true
313+ }
314+ return oneTableMatch , allTablesMatch
315+ }
316+
317+ func checkParenTable (table * sqlparser.ParenTableExpr , setOfTables map [string ]bool ) (bool , bool ) {
318+ singleTableMatch , allTablesMatch := checkTableExprsMatch (table .Exprs , setOfTables )
319+ return singleTableMatch , allTablesMatch
210320}
0 commit comments