diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index a9c0fa3d3..6d1fb0b54 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -36,5 +36,6 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLen{}) Walk(node, &filterLast{}) Walk(node, &filterFirst{}) + Walk(node, &predicateCombination{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index e45de763b..703bd1ceb 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -1,6 +1,7 @@ package optimizer_test import ( + "fmt" "reflect" "strings" "testing" @@ -339,3 +340,124 @@ func TestOptimize_filter_map_first(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } + +func TestOptimize_predicate_combination(t *testing.T) { + tests := []struct { + op string + fn string + wantOp string + }{ + {"and", "all", "and"}, + {"&&", "all", "&&"}, + {"or", "all", "or"}, + {"||", "all", "||"}, + {"and", "any", "and"}, + {"&&", "any", "&&"}, + {"or", "any", "or"}, + {"||", "any", "||"}, + {"and", "none", "or"}, + {"&&", "none", "||"}, + {"and", "one", "or"}, + {"&&", "one", "||"}, + } + + for _, tt := range tests { + rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn) + t.Run(rule, func(t *testing.T) { + tree, err := parser.Parse(rule) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: tt.fn, + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: tt.wantOp, + Left: &ast.BinaryNode{ + Operator: "and", + Left: &ast.BinaryNode{ + Operator: ">", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + Right: &ast.BinaryNode{ + Operator: "<", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 30}, + }, + }, + }, + }, + } + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) + }) + } +} + +func TestOptimize_predicate_combination_nested(t *testing.T) { + tree, err := parser.Parse(`any(users, {all(.Friends, {.Age == 18 })}) && any(users, {all(.Friends, {.Name != "Bob" })})`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: "any", + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Friends"}, + }, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: "&&", + Left: &ast.BinaryNode{ + Operator: "==", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + }, + }, + }, + }, + }, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go new file mode 100644 index 000000000..2733781df --- /dev/null +++ b/optimizer/predicate_combination.go @@ -0,0 +1,51 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/parser/operator" +) + +type predicateCombination struct{} + +func (v *predicateCombination) Visit(node *Node) { + if op, ok := (*node).(*BinaryNode); ok && operator.IsBoolean(op.Operator) { + if left, ok := op.Left.(*BuiltinNode); ok { + if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { + if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { + if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { + closure := &ClosureNode{ + Node: &BinaryNode{ + Operator: combinedOp, + Left: left.Arguments[1].(*ClosureNode).Node, + Right: right.Arguments[1].(*ClosureNode).Node, + }, + } + v.Visit(&closure.Node) + Patch(node, &BuiltinNode{ + Name: left.Name, + Arguments: []Node{ + left.Arguments[0], + closure, + }, + }) + } + } + } + } + } +} + +func combinedOperator(fn, op string) (string, bool) { + switch fn { + case "all", "any": + return op, true + case "one", "none": + switch op { + case "and": + return "or", true + case "&&": + return "||", true + } + } + return "", false +}