diff --git a/ast/node.go b/ast/node.go index dfdf715fc..d36ebe107 100644 --- a/ast/node.go +++ b/ast/node.go @@ -234,3 +234,11 @@ type PairNode struct { Key Node // Key of the pair. Value Node // Value of the pair. } + +// CompareNode represents comparison +type CompareNode struct { + base + Left Node // Left represents the left-hand side of the comparison operation + Operators []string // Operators is a list of comparison operator tokens used in the comparison. + Comparators []Node // Comparators representing the right-hand sides of the comparison operation +} diff --git a/ast/print.go b/ast/print.go index 6a7d698a9..f19c59356 100644 --- a/ast/print.go +++ b/ast/print.go @@ -219,3 +219,34 @@ func (n *PairNode) String() string { } return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String()) } + +func (n *CompareNode) string(node Node) string { + switch v := node.(type) { + case *BinaryNode, *CompareNode: + return fmt.Sprintf("(%s)", v) + default: + return v.String() + } +} + +func (n *CompareNode) String() string { + var builder strings.Builder + builder.WriteString(n.string(n.Left)) + opIdx := 0 + for i := 0; i < len(n.Comparators); i++ { + if op := n.Operators[opIdx]; op != "&&" { + builder.WriteByte(' ') + builder.WriteString(op) + if op == "not" { + opIdx++ + builder.WriteByte(' ') + builder.WriteString(n.Operators[opIdx]) + } + builder.WriteByte(' ') + builder.WriteString(n.string(n.Comparators[i])) + } + opIdx++ + } + + return builder.String() +} diff --git a/ast/print_test.go b/ast/print_test.go index 51edd63f5..a022219c8 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -41,7 +41,7 @@ func TestPrint(t *testing.T) { {`a == b`, `a == b`}, {`a matches b`, `a matches b`}, {`a in b`, `a in b`}, - {`a not in b`, `not (a in b)`}, + {`a not in b`, `a not in b`}, {`a and b`, `a and b`}, {`a or b`, `a or b`}, {`a or b and c`, `a or (b and c)`}, diff --git a/ast/visitor.go b/ast/visitor.go index 90bc9f1d0..5722f6c12 100644 --- a/ast/visitor.go +++ b/ast/visitor.go @@ -66,6 +66,11 @@ func Walk(node *Node, v Visitor) { case *PairNode: Walk(&n.Key, v) Walk(&n.Value, v) + case *CompareNode: + Walk(&n.Left, v) + for i := range n.Comparators { + Walk(&n.Comparators[i], v) + } default: panic(fmt.Sprintf("undefined node type (%T)", node)) } diff --git a/checker/checker.go b/checker/checker.go index fae8f5a16..95d02939f 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -166,6 +166,8 @@ func (v *checker) visit(node ast.Node) Nature { nt = v.MapNode(n) case *ast.PairNode: nt = v.PairNode(n) + case *ast.CompareNode: + nt = v.CompareNode(n) default: panic(fmt.Sprintf("undefined node type (%T)", node)) } @@ -282,11 +284,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { r = r.Deref() switch node.Operator { - case "==", "!=": - if isComparable(l, r) { - return boolNature - } - case "or", "||", "and", "&&": if isBool(l) && isBool(r) { return boolNature @@ -295,20 +292,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { return boolNature } - case "<", ">", ">=", "<=": - if isNumber(l) && isNumber(r) { - return boolNature - } - if isString(l) && isString(r) { - return boolNature - } - if isTime(l) && isTime(r) { - return boolNature - } - if or(l, r, isNumber, isString, isTime) { - return boolNature - } - case "-": if isNumber(l) && isNumber(r) { return combined(l, r) @@ -372,51 +355,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { return unknown } - case "in": - if (isString(l) || isUnknown(l)) && isStruct(r) { - return boolNature - } - if isMap(r) { - if !isUnknown(l) && !l.AssignableTo(r.Key()) { - return v.error(node, "cannot use %v as type %v in map key", l, r.Key()) - } - return boolNature - } - if isArray(r) { - if !isComparable(l, r.Elem()) { - return v.error(node, "cannot use %v as type %v in array", l, r.Elem()) - } - return boolNature - } - if isUnknown(l) && anyOf(r, isString, isArray, isMap) { - return boolNature - } - if isUnknown(r) { - return boolNature - } - - case "matches": - if s, ok := node.Right.(*ast.StringNode); ok { - _, err := regexp.Compile(s.Value) - if err != nil { - return v.error(node, err.Error()) - } - } - if isString(l) && isString(r) { - return boolNature - } - if or(l, r, isString) { - return boolNature - } - - case "contains", "startsWith", "endsWith": - if isString(l) && isString(r) { - return boolNature - } - if or(l, r, isString) { - return boolNature - } - case "..": if isInteger(l) && isInteger(r) { return Nature{ @@ -1230,3 +1168,87 @@ func (v *checker) PairNode(node *ast.PairNode) Nature { v.visit(node.Value) return nilNature } + +func (v *checker) CompareNode(node *ast.CompareNode) Nature { + nodeLeft := node.Left + opIdx := 0 + operatorOverride := false + for i, comparator := range node.Comparators { + op := node.Operators[opIdx] + if negate := op == "not"; negate { + opIdx++ + op = node.Operators[opIdx] + } + if op == "&&" { + if !operatorOverride { + operatorOverride = true + } + } else if err := v.compareNode(op, nodeLeft, comparator, i); err != nil { + return v.error(comparator, err.Error()) + } + opIdx++ + nodeLeft = comparator + } + if operatorOverride { + return unknown + } + return boolNature +} + +func (v *checker) compareNode(op string, nodeLeft, nodeRight ast.Node, index int) error { + l := v.visit(nodeLeft) + r := v.visit(nodeRight) + + l = l.Deref() + r = r.Deref() + switch op { + case "==", "!=": + if (isBool(r) && index > 0) || isComparable(l, r) { + return nil + } + case "<", ">", ">=", "<=": + if isNumber(l) && isNumber(r) || + isString(l) && isString(r) || + isTime(l) && isTime(r) || + or(l, r, isNumber, isString, isTime) { + return nil + } + case "in": + if (isString(l) || isUnknown(l)) && isStruct(r) { + return nil + } + if isMap(r) { + if !isUnknown(l) && !l.AssignableTo(r.Key()) { + return fmt.Errorf("cannot use %v as type %v in map key", l, r.Key()) + } + return nil + } + if isArray(r) { + if !isComparable(l, r.Elem()) { + return fmt.Errorf("cannot use %v as type %v in array", l, r.Elem()) + } + return nil + } + if (isUnknown(l) && anyOf(r, isString, isArray, isMap)) || isUnknown(r) { + return nil + } + + case "matches": + if s, ok := nodeRight.(*ast.StringNode); ok { + if _, err := regexp.Compile(s.Value); err != nil { + return err + } + } + if (isString(l) && isString(r)) || or(l, r, isString) { + return nil + } + case "contains", "startsWith", "endsWith": + if isString(l) && isString(r) || + or(l, r, isString) { + return nil + } + default: + return fmt.Errorf("unknown operator (%v)", op) + } + return fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, op, l, r) +} diff --git a/compiler/compiler.go b/compiler/compiler.go index 5e993540d..329996057 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -273,6 +273,8 @@ func (c *compiler) compile(node ast.Node) { c.MapNode(n) case *ast.PairNode: c.PairNode(n) + case *ast.CompareNode: + c.CompareNode(n) default: panic(fmt.Sprintf("undefined node type (%T)", node)) } @@ -454,34 +456,6 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { c.derefInNeeded(node.Right) c.patchJump(end) - case "<": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpLess) - - case ">": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpMore) - - case "<=": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpLessOrEqual) - - case ">=": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpMoreOrEqual) - case "+": c.compile(node.Left) c.derefInNeeded(node.Left) @@ -524,51 +498,6 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { c.derefInNeeded(node.Right) c.emit(OpExponent) - case "in": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpIn) - - case "matches": - if str, ok := node.Right.(*ast.StringNode); ok { - re, err := regexp.Compile(str.Value) - if err != nil { - panic(err) - } - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.emit(OpMatchesConst, c.addConstant(re)) - } else { - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpMatches) - } - - case "contains": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpContains) - - case "startsWith": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpStartsWith) - - case "endsWith": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpEndsWith) - case "..": c.compile(node.Left) c.derefInNeeded(node.Left) @@ -1198,6 +1127,113 @@ func (c *compiler) PairNode(node *ast.PairNode) { c.compile(node.Value) } +func (c *compiler) compileAndDeref(node ast.Node) { + c.compile(node) + c.derefInNeeded(node) +} + +func (c *compiler) CompareNode(node *ast.CompareNode) { + var end int + jump := false + nodeLeft := node.Left + comparatorSize := len(node.Comparators) - 1 + opIdx := 0 + for i := 0; i <= comparatorSize; i++ { + comparator := node.Comparators[i] + op := node.Operators[opIdx] + negate := op == "not" + if negate { + opIdx++ + op = node.Operators[opIdx] + } + switch op { + case "&&": + if i == 0 { + c.compileAndDeref(nodeLeft) + } else { + c.compileAndDeref(comparator) + } + case "matches": + c.compileAndDeref(nodeLeft) + if str, ok := comparator.(*ast.StringNode); ok { + re, err := regexp.Compile(str.Value) + if err != nil { + panic(err) + } + c.emit(OpMatchesConst, c.addConstant(re)) + } else { + c.compileAndDeref(comparator) + c.emit(OpMatches) + } + case "==", "!=": + r := kind(comparator.Type()) + if i != 0 && r == reflect.Bool { + c.bytecode = c.bytecode[:len(c.bytecode)-2] + c.arguments = c.arguments[:len(c.arguments)-2] + c.locations = c.locations[:len(c.locations)-2] + jump = false + c.compileAndDeref(comparator) + c.emit(OpEqual) + } else { + c.compileAndDeref(nodeLeft) + c.compileAndDeref(comparator) + l := kind(nodeLeft.Type()) + if l == r && isSimpleType(nodeLeft) && isSimpleType(comparator) { + switch l { + case reflect.Int: + c.emit(OpEqualInt) + case reflect.String: + c.emit(OpEqualString) + default: + c.emit(OpEqual) + } + } else { + c.emit(OpEqual) + } + } + if op == "!=" { + c.emit(OpNot) + } + default: + c.compileAndDeref(nodeLeft) + c.compileAndDeref(comparator) + switch op { + case "<": + c.emit(OpLess) + case ">": + c.emit(OpMore) + case "<=": + c.emit(OpLessOrEqual) + case ">=": + c.emit(OpMoreOrEqual) + case "in": + c.emit(OpIn) + case "contains": + c.emit(OpContains) + case "startsWith": + c.emit(OpStartsWith) + case "endsWith": + c.emit(OpEndsWith) + default: + panic(fmt.Sprintf("unknown operator (%v)", op)) + } + } + if negate { + c.emit(OpNot) + } + if jump { + c.patchJump(end) + } + if i < comparatorSize { + end = c.emit(OpJumpIfFalse, placeholder) + c.emit(OpPop) + jump = true + } + nodeLeft = comparator + opIdx++ + } +} + func (c *compiler) derefInNeeded(node ast.Node) { switch node.Type().Kind() { case reflect.Ptr, reflect.Interface: diff --git a/optimizer/in_array.go b/optimizer/in_array.go index e91320c0f..62c91c09f 100644 --- a/optimizer/in_array.go +++ b/optimizer/in_array.go @@ -10,22 +10,33 @@ type inArray struct{} func (*inArray) Visit(node *Node) { switch n := (*node).(type) { - case *BinaryNode: - if n.Operator == "in" { - if array, ok := n.Right.(*ArrayNode); ok { - if len(array.Nodes) > 0 { - t := n.Left.Type() - if t == nil || t.Kind() != reflect.Int { - // This optimization can be only performed if left side is int type, - // as runtime.in func uses reflect.Map.MapIndex and keys of map must, - // be same as checked value type. + case *CompareNode: + for i := 0; i < len(n.Operators); i++ { + op := n.Operators[i] + negate := op == "not" + if negate { + i++ + op = n.Operators[i] + } + + if op == "in" { + comparatorIdx := i + if negate { + comparatorIdx = i - 1 + } + if array, ok := n.Comparators[comparatorIdx].(*ArrayNode); ok && len(array.Nodes) > 0 { + var lType reflect.Type + if comparatorIdx == 0 { + lType = n.Left.Type() + } else { + lType = n.Comparators[comparatorIdx-1].Type() + } + if lType == nil || lType.Kind() != reflect.Int { goto string } - for _, a := range array.Nodes { - if _, ok := a.(*IntegerNode); !ok { - goto string - } + if !allIntegerNodes(array.Nodes) { + goto string } { value := make(map[int]struct{}) @@ -34,18 +45,12 @@ func (*inArray) Visit(node *Node) { } m := &ConstantNode{Value: value} m.SetType(reflect.TypeOf(value)) - patchCopyType(node, &BinaryNode{ - Operator: n.Operator, - Left: n.Left, - Right: m, - }) + n.Comparators[comparatorIdx] = m } string: - for _, a := range array.Nodes { - if _, ok := a.(*StringNode); !ok { - return - } + if !allStringNodes(array.Nodes) { + continue } { value := make(map[string]struct{}) @@ -54,11 +59,7 @@ func (*inArray) Visit(node *Node) { } m := &ConstantNode{Value: value} m.SetType(reflect.TypeOf(value)) - patchCopyType(node, &BinaryNode{ - Operator: n.Operator, - Left: n.Left, - Right: m, - }) + n.Comparators[comparatorIdx] = m } } @@ -66,3 +67,21 @@ func (*inArray) Visit(node *Node) { } } } + +func allIntegerNodes(nodes []Node) bool { + for _, n := range nodes { + if _, ok := n.(*IntegerNode); !ok { + return false + } + } + return true +} + +func allStringNodes(nodes []Node) bool { + for _, n := range nodes { + if _, ok := n.(*StringNode); !ok { + return false + } + } + return true +} diff --git a/optimizer/in_range.go b/optimizer/in_range.go index ed2f557ea..6e5771cb0 100644 --- a/optimizer/in_range.go +++ b/optimizer/in_range.go @@ -10,31 +10,78 @@ type inRange struct{} func (*inRange) Visit(node *Node) { switch n := (*node).(type) { - case *BinaryNode: - if n.Operator == "in" { - t := n.Left.Type() - if t == nil { - return + case *CompareNode: + opSize := len(n.Operators) + for i := 0; i < opSize; i++ { + op := n.Operators[i] + negate := op == "not" + if negate { + i++ + op = n.Operators[i] } - if t.Kind() != reflect.Int { - return - } - if rangeOp, ok := n.Right.(*BinaryNode); ok && rangeOp.Operator == ".." { - if from, ok := rangeOp.Left.(*IntegerNode); ok { - if to, ok := rangeOp.Right.(*IntegerNode); ok { - patchCopyType(node, &BinaryNode{ - Operator: "and", - Left: &BinaryNode{ - Operator: ">=", - Left: n.Left, - Right: from, - }, - Right: &BinaryNode{ - Operator: "<=", - Left: n.Left, - Right: to, - }, - }) + if op == "in" { + comparatorIdx := i + if negate { + comparatorIdx = i - 1 + } + if rangeOp, ok := n.Comparators[comparatorIdx].(*BinaryNode); ok && rangeOp.Operator == ".." { + if from, ok := rangeOp.Left.(*IntegerNode); ok { + if to, ok := rangeOp.Right.(*IntegerNode); ok { + var lNode Node + if comparatorIdx == 0 { + lNode = n.Left + } else { + lNode = n.Comparators[comparatorIdx-1] + } + if lType := lNode.Type(); lType != nil && lType.Kind() == reflect.Int { + if comparatorIdx == 0 { + if len(n.Comparators) == 1 { + n.Operators = []string{"<=", "<="} + n.Comparators = []Node{lNode, to} + n.Left = from + if negate { + patchCopyType(node, &UnaryNode{ + Operator: "not", + Node: n, + }) + } + return + } else { + n.Operators[comparatorIdx] = "&&" + n.Left = &CompareNode{ + Left: from, + Operators: []string{"<=", "<="}, + Comparators: []Node{lNode, to}, + } + n.Comparators = n.Comparators[comparatorIdx+1:] + if negate { + n.Left = &UnaryNode{ + Operator: "not", + Node: n.Left, + } + n.Operators = append(n.Operators[:comparatorIdx+1], n.Operators[comparatorIdx+2:]...) + opSize-- + } + } + } else { + n.Operators[comparatorIdx] = "&&" + var comparator Node = &CompareNode{ + Left: from, + Operators: []string{"<=", "<="}, + Comparators: []Node{lNode, to}, + } + if negate { + comparator = &UnaryNode{ + Operator: "not", + Node: comparator, + } + n.Operators = append(n.Operators[:comparatorIdx+1], n.Operators[comparatorIdx+2:]...) + opSize-- + } + n.Comparators[comparatorIdx] = comparator + } + } + } } } } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 56a890492..0cf4b7592 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -103,47 +103,66 @@ func TestOptimize_in_array(t *testing.T) { err = optimizer.Optimize(&tree.Node, nil) require.NoError(t, err) - expected := &ast.BinaryNode{ - Operator: "in", - Left: &ast.IdentifierNode{Value: "v"}, - Right: &ast.ConstantNode{Value: map[int]struct{}{1: {}, 2: {}, 3: {}}}, + expected := &ast.CompareNode{ + Left: &ast.IdentifierNode{Value: "v"}, + Operators: []string{"in"}, + Comparators: []ast.Node{&ast.ConstantNode{Value: map[int]struct{}{1: {}, 2: {}, 3: {}}}}, } assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } func TestOptimize_in_range(t *testing.T) { - tree, err := parser.Parse(`age in 18..31`) - require.NoError(t, err) - - config := conf.New(map[string]int{"age": 30}) - _, err = checker.Check(tree, config) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - left := &ast.IdentifierNode{ - Value: "age", - } - expected := &ast.BinaryNode{ - Operator: "and", - Left: &ast.BinaryNode{ - Operator: ">=", - Left: left, - Right: &ast.IntegerNode{ - Value: 18, - }, - }, - Right: &ast.BinaryNode{ - Operator: "<=", - Left: left, - Right: &ast.IntegerNode{ - Value: 31, + tests := []struct { + code string + expected ast.Node + }{ + {`age in 18..31`, &ast.CompareNode{ + Left: &ast.IntegerNode{Value: 18}, + Operators: []string{"<=", "<="}, + Comparators: []ast.Node{&ast.IdentifierNode{Value: "age"}, &ast.IntegerNode{Value: 31}}, + }}, + {`age in 18..31 == true`, &ast.CompareNode{ + Left: &ast.CompareNode{ + Operators: []string{"<=", "<="}, + Left: &ast.IntegerNode{Value: 18}, + Comparators: []ast.Node{&ast.IdentifierNode{Value: "age"}, &ast.IntegerNode{Value: 31}}, }, - }, + Operators: []string{"&&", "=="}, + Comparators: []ast.Node{&ast.BoolNode{Value: true}}, + }}, + {`age > 19 in 18..31`, &ast.CompareNode{ + Left: &ast.IdentifierNode{Value: "age"}, + Operators: []string{">", "&&"}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 19}, &ast.CompareNode{ + Operators: []string{"<=", "<="}, + Left: &ast.IntegerNode{Value: 18}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 19}, &ast.IntegerNode{Value: 31}}, + }}, + }}, + {`age > 19 in 18..31 != true`, &ast.CompareNode{ + Left: &ast.IdentifierNode{Value: "age"}, + Operators: []string{">", "&&", "!="}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 19}, &ast.CompareNode{ + Operators: []string{"<=", "<="}, + Left: &ast.IntegerNode{Value: 18}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 19}, &ast.IntegerNode{Value: 31}}, + }, &ast.BoolNode{Value: true}}, + }}, } + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + tree, err := parser.Parse(tt.code) + require.NoError(t, err) - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) + config := conf.New(map[string]int{"age": 30}) + _, err = checker.Check(tree, config) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + assert.Equal(t, ast.Dump(tt.expected), ast.Dump(tree.Node)) + }) + } } func TestOptimize_in_range_with_floats(t *testing.T) { @@ -185,13 +204,13 @@ func TestOptimize_filter_len(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -212,13 +231,13 @@ func TestOptimize_filter_0(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -240,13 +259,13 @@ func TestOptimize_filter_first(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -268,13 +287,13 @@ func TestOptimize_filter_minus_1(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -296,13 +315,13 @@ func TestOptimize_filter_last(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -324,13 +343,13 @@ func TestOptimize_filter_map(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -355,13 +374,13 @@ func TestOptimize_filter_map_first(t *testing.T) { Arguments: []ast.Node{ &ast.IdentifierNode{Value: "users"}, &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "==", + Node: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, @@ -407,30 +426,30 @@ func TestOptimize_predicate_combination(t *testing.T) { Operator: tt.wantOp, Left: &ast.BinaryNode{ Operator: "and", - Left: &ast.BinaryNode{ - Operator: ">", + Left: &ast.CompareNode{ + Operators: []string{">"}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Age"}, }, - Right: &ast.IntegerNode{Value: 18}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 18}}, }, - Right: &ast.BinaryNode{ - Operator: "!=", + Right: &ast.CompareNode{ + Operators: []string{"!="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, - Right: &ast.BinaryNode{ - Operator: "<", + Right: &ast.CompareNode{ + Operators: []string{"<"}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Age"}, }, - Right: &ast.IntegerNode{Value: 30}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 30}}, }, }, }, @@ -463,21 +482,21 @@ func TestOptimize_predicate_combination_nested(t *testing.T) { &ast.ClosureNode{ Node: &ast.BinaryNode{ Operator: "&&", - Left: &ast.BinaryNode{ - Operator: "==", + Left: &ast.CompareNode{ + Operators: []string{"=="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Age"}, }, - Right: &ast.IntegerNode{Value: 18}, + Comparators: []ast.Node{&ast.IntegerNode{Value: 18}}, }, - Right: &ast.BinaryNode{ - Operator: "!=", + Right: &ast.CompareNode{ + Operators: []string{"!="}, Left: &ast.MemberNode{ Node: &ast.PointerNode{}, Property: &ast.StringNode{Value: "Name"}, }, - Right: &ast.StringNode{Value: "Bob"}, + Comparators: []ast.Node{&ast.StringNode{Value: "Bob"}}, }, }, }, diff --git a/parser/operator/operator.go b/parser/operator/operator.go index 4eeaf80ed..ec99d8681 100644 --- a/parser/operator/operator.go +++ b/parser/operator/operator.go @@ -65,5 +65,6 @@ var Binary = map[string]Operator{ } func IsComparison(op string) bool { - return op == "<" || op == ">" || op == ">=" || op == "<=" + val, ok := Binary[op] + return ok && val.Precedence == 20 } diff --git a/parser/parser.go b/parser/parser.go index 77b2a700a..290392858 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -127,7 +127,6 @@ func (p *parser) expect(kind Kind, values ...string) { } // parse functions - func (p *parser) parseExpression(precedence int) Node { if precedence == 0 && p.current.Is(Operator, "let") { return p.parseVariableDeclaration() @@ -175,8 +174,8 @@ func (p *parser) parseExpression(precedence int) Node { break } - if operator.IsComparison(opToken.Value) { - nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) + if op.Precedence == 20 { + nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence+1, negate) goto next } @@ -701,33 +700,35 @@ func (p *parser) parsePostfixExpression(node Node) Node { return node } -func (p *parser) parseComparison(left Node, token Token, precedence int) Node { - var rootNode Node +func (p *parser) parseComparison(left Node, token Token, precedence int, negate bool) Node { + compareNode := &CompareNode{ + Left: left, + Operators: make([]string, 0), + Comparators: make([]Node, 0), + } + if negate { + compareNode.Operators = append(compareNode.Operators, "not") + } for { comparator := p.parseExpression(precedence + 1) - cmpNode := &BinaryNode{ - Operator: token.Value, - Left: left, - Right: comparator, + comparator.SetLocation(token.Location) + compareNode.Operators = append(compareNode.Operators, token.Value) + compareNode.Comparators = append(compareNode.Comparators, comparator) + if p.err != nil || !p.current.Is(Operator) || (!operator.IsComparison(p.current.Value) && p.current.Value != "not") { + break } - cmpNode.SetLocation(token.Location) - if rootNode == nil { - rootNode = cmpNode - } else { - rootNode = &BinaryNode{ - Operator: "&&", - Left: rootNode, - Right: cmpNode, + if p.current.Value == "not" { + p.next() + if !operator.AllowedNegateSuffix(p.current.Value) { + p.error("unexpected token %v", p.current) + break } - rootNode.SetLocation(token.Location) + compareNode.Operators = append(compareNode.Operators, "not") } - - left = comparator token = p.current - if !(token.Is(Operator) && operator.IsComparison(token.Value) && p.err == nil) { - break - } p.next() } - return rootNode + + compareNode.SetLocation(left.Location()) + return compareNode } diff --git a/parser/parser_test.go b/parser/parser_test.go index 3c6ee5b2b..e93a3c16b 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -239,17 +239,17 @@ world`}, }, { "'a' == 'b'", - &BinaryNode{Operator: "==", - Left: &StringNode{Value: "a"}, - Right: &StringNode{Value: "b"}}, + &CompareNode{Operators: []string{"=="}, + Left: &StringNode{Value: "a"}, + Comparators: []Node{&StringNode{Value: "b"}}}, }, { "+0 != -0", - &BinaryNode{Operator: "!=", + &CompareNode{Operators: []string{"!="}, Left: &UnaryNode{Operator: "+", Node: &IntegerNode{}}, - Right: &UnaryNode{Operator: "-", - Node: &IntegerNode{}}}, + Comparators: []Node{&UnaryNode{Operator: "-", + Node: &IntegerNode{}}}}, }, { "[a, b, c]", @@ -300,53 +300,50 @@ world`}, }, { `foo matches "foo"`, - &BinaryNode{ - Operator: "matches", - Left: &IdentifierNode{Value: "foo"}, - Right: &StringNode{Value: "foo"}}, + &CompareNode{ + Operators: []string{"matches"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&StringNode{Value: "foo"}}}, }, { `foo not matches "foo"`, - &UnaryNode{ - Operator: "not", - Node: &BinaryNode{ - Operator: "matches", - Left: &IdentifierNode{Value: "foo"}, - Right: &StringNode{Value: "foo"}}}, + &CompareNode{ + Operators: []string{"not", "matches"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&StringNode{Value: "foo"}}}, }, { `foo matches regex`, - &BinaryNode{ - Operator: "matches", - Left: &IdentifierNode{Value: "foo"}, - Right: &IdentifierNode{Value: "regex"}}, + &CompareNode{ + Operators: []string{"matches"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&IdentifierNode{Value: "regex"}}}, }, { `foo contains "foo"`, - &BinaryNode{ - Operator: "contains", - Left: &IdentifierNode{Value: "foo"}, - Right: &StringNode{Value: "foo"}}, + &CompareNode{ + Operators: []string{"contains"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&StringNode{Value: "foo"}}}, }, { `foo not contains "foo"`, - &UnaryNode{ - Operator: "not", - Node: &BinaryNode{Operator: "contains", - Left: &IdentifierNode{Value: "foo"}, - Right: &StringNode{Value: "foo"}}}, + &CompareNode{ + Operators: []string{"not", "contains"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&StringNode{Value: "foo"}}}, }, { `foo startsWith "foo"`, - &BinaryNode{Operator: "startsWith", - Left: &IdentifierNode{Value: "foo"}, - Right: &StringNode{Value: "foo"}}, + &CompareNode{Operators: []string{"startsWith"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&StringNode{Value: "foo"}}}, }, { `foo endsWith "foo"`, - &BinaryNode{Operator: "endsWith", - Left: &IdentifierNode{Value: "foo"}, - Right: &StringNode{Value: "foo"}}, + &CompareNode{Operators: []string{"endsWith"}, + Left: &IdentifierNode{Value: "foo"}, + Comparators: []Node{&StringNode{Value: "foo"}}}, }, { "1..9", @@ -356,9 +353,9 @@ world`}, }, { "0 in []", - &BinaryNode{Operator: "in", - Left: &IntegerNode{}, - Right: &ArrayNode{Nodes: []Node{}}}, + &CompareNode{Operators: []string{"in"}, + Left: &IntegerNode{}, + Comparators: []Node{&ArrayNode{Nodes: []Node{}}}}, }, { "not in_var", @@ -367,59 +364,54 @@ world`}, }, { "-1 not in [1, 2, 3, 4]", - &UnaryNode{Operator: "not", - Node: &BinaryNode{Operator: "in", - Left: &UnaryNode{Operator: "-", Node: &IntegerNode{Value: 1}}, - Right: &ArrayNode{Nodes: []Node{ - &IntegerNode{Value: 1}, - &IntegerNode{Value: 2}, - &IntegerNode{Value: 3}, - &IntegerNode{Value: 4}, - }}}}, + &CompareNode{Operators: []string{"not", "in"}, + Left: &UnaryNode{Operator: "-", Node: &IntegerNode{Value: 1}}, + Comparators: []Node{&ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, }, { "1*8 not in [1, 2, 3, 4]", - &UnaryNode{Operator: "not", - Node: &BinaryNode{Operator: "in", - Left: &BinaryNode{Operator: "*", - Left: &IntegerNode{Value: 1}, - Right: &IntegerNode{Value: 8}, - }, - Right: &ArrayNode{Nodes: []Node{ - &IntegerNode{Value: 1}, - &IntegerNode{Value: 2}, - &IntegerNode{Value: 3}, - &IntegerNode{Value: 4}, - }}}}, + &CompareNode{Operators: []string{"not", "in"}, + Left: &BinaryNode{Operator: "*", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 8}, + }, + Comparators: []Node{&ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, }, { "2==2 ? false : 3 not in [1, 2, 5]", &ConditionalNode{ - Cond: &BinaryNode{ - Operator: "==", - Left: &IntegerNode{Value: 2}, - Right: &IntegerNode{Value: 2}, + Cond: &CompareNode{ + Operators: []string{"=="}, + Left: &IntegerNode{Value: 2}, + Comparators: []Node{&IntegerNode{Value: 2}}, }, Exp1: &BoolNode{Value: false}, - Exp2: &UnaryNode{ - Operator: "not", - Node: &BinaryNode{ - Operator: "in", - Left: &IntegerNode{Value: 3}, - Right: &ArrayNode{Nodes: []Node{ - &IntegerNode{Value: 1}, - &IntegerNode{Value: 2}, - &IntegerNode{Value: 5}, - }}}}}, + Exp2: &CompareNode{ + Operators: []string{"not", "in"}, + Left: &IntegerNode{Value: 3}, + Comparators: []Node{&ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 5}, + }}}}}, }, { "'foo' + 'bar' not matches 'foobar'", - &UnaryNode{Operator: "not", - Node: &BinaryNode{Operator: "matches", - Left: &BinaryNode{Operator: "+", - Left: &StringNode{Value: "foo"}, - Right: &StringNode{Value: "bar"}}, - Right: &StringNode{Value: "foobar"}}}, + &CompareNode{Operators: []string{"not", "matches"}, + Left: &BinaryNode{Operator: "+", + Left: &StringNode{Value: "foo"}, + Right: &StringNode{Value: "bar"}}, + Comparators: []Node{&StringNode{Value: "foobar"}}}, }, { "all(Tickets, #)", @@ -438,11 +430,11 @@ world`}, Arguments: []Node{ &IdentifierNode{Value: "Tickets"}, &ClosureNode{ - Node: &BinaryNode{ - Operator: ">", + Node: &CompareNode{ + Operators: []string{">"}, Left: &MemberNode{Node: &PointerNode{}, Property: &StringNode{Value: "Price"}}, - Right: &IntegerNode{Value: 0}}}}}, + Comparators: []Node{&IntegerNode{Value: 0}}}}}}, }, { "one(Tickets, {#.Price > 0})", @@ -451,21 +443,21 @@ world`}, Arguments: []Node{ &IdentifierNode{Value: "Tickets"}, &ClosureNode{ - Node: &BinaryNode{ - Operator: ">", + Node: &CompareNode{ + Operators: []string{">"}, Left: &MemberNode{ Node: &PointerNode{}, Property: &StringNode{Value: "Price"}, }, - Right: &IntegerNode{Value: 0}}}}}, + Comparators: []Node{&IntegerNode{Value: 0}}}}}}, }, { "filter(Prices, {# > 100})", &BuiltinNode{Name: "filter", Arguments: []Node{&IdentifierNode{Value: "Prices"}, - &ClosureNode{Node: &BinaryNode{Operator: ">", - Left: &PointerNode{}, - Right: &IntegerNode{Value: 100}}}}}, + &ClosureNode{Node: &CompareNode{Operators: []string{">"}, + Left: &PointerNode{}, + Comparators: []Node{&IntegerNode{Value: 100}}}}}}, }, { "array[1:2]", @@ -589,62 +581,49 @@ world`}, }, { `1 < 2 > 3`, - &BinaryNode{ - Operator: "&&", - Left: &BinaryNode{ - Operator: "<", - Left: &IntegerNode{Value: 1}, - Right: &IntegerNode{Value: 2}, - }, - Right: &BinaryNode{ - Operator: ">", - Left: &IntegerNode{Value: 2}, - Right: &IntegerNode{Value: 3}, + &CompareNode{ + Operators: []string{"<", ">"}, + Left: &IntegerNode{Value: 1}, + Comparators: []Node{ + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, }, }, }, { `1 < 2 < 3 < 4`, - &BinaryNode{ - Operator: "&&", - Left: &BinaryNode{ - Operator: "&&", - Left: &BinaryNode{ - Operator: "<", - Left: &IntegerNode{Value: 1}, - Right: &IntegerNode{Value: 2}, - }, - Right: &BinaryNode{ - Operator: "<", - Left: &IntegerNode{Value: 2}, - Right: &IntegerNode{Value: 3}, - }, - }, - Right: &BinaryNode{ - Operator: "<", - Left: &IntegerNode{Value: 3}, - Right: &IntegerNode{Value: 4}, + &CompareNode{ + Operators: []string{"<", "<", "<"}, + Left: &IntegerNode{Value: 1}, + Comparators: []Node{ + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, }, }, }, { `1 < 2 < 3 == true`, - &BinaryNode{ - Operator: "==", - Left: &BinaryNode{ - Operator: "&&", - Left: &BinaryNode{ - Operator: "<", - Left: &IntegerNode{Value: 1}, - Right: &IntegerNode{Value: 2}, - }, - Right: &BinaryNode{ - Operator: "<", - Left: &IntegerNode{Value: 2}, - Right: &IntegerNode{Value: 3}, - }, + &CompareNode{ + Operators: []string{"<", "<", "=="}, + Left: &IntegerNode{Value: 1}, + Comparators: []Node{ + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &BoolNode{Value: true}, + }, + }, + }, + { + `1 < 2 < 3 not contains [1, 2]`, + &CompareNode{ + Operators: []string{"<", "<", "not", "contains"}, + Left: &IntegerNode{Value: 1}, + Comparators: []Node{ + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &ArrayNode{Nodes: []Node{&IntegerNode{Value: 1}, &IntegerNode{Value: 2}}}, }, - Right: &BoolNode{Value: true}, }, }, } diff --git a/patcher/operator_override.go b/patcher/operator_override.go index 551fe09bb..2d5064e36 100644 --- a/patcher/operator_override.go +++ b/patcher/operator_override.go @@ -18,27 +18,57 @@ type OperatorOverloading struct { } func (p *OperatorOverloading) Visit(node *ast.Node) { - binaryNode, ok := (*node).(*ast.BinaryNode) - if !ok { - return - } - - if binaryNode.Operator != p.Operator { - return - } - - leftType := binaryNode.Left.Type() - rightType := binaryNode.Right.Type() - - ret, fn, ok := p.FindSuitableOperatorOverload(leftType, rightType) - if ok { - newNode := &ast.CallNode{ - Callee: &ast.IdentifierNode{Value: fn}, - Arguments: []ast.Node{binaryNode.Left, binaryNode.Right}, + switch n := (*node).(type) { + case *ast.BinaryNode: + if n.Operator != p.Operator { + return + } + if ret, fn, ok := p.FindSuitableOperatorOverload(n.Left.Type(), n.Right.Type()); ok { + newNode := &ast.CallNode{ + Callee: &ast.IdentifierNode{Value: fn}, + Arguments: []ast.Node{n.Left, n.Right}, + } + newNode.SetType(ret) + ast.Patch(node, newNode) + p.applied = true + } + case *ast.CompareNode: + nodeLeft := n.Left + opIdx := 0 + for i, comparator := range n.Comparators { + op := n.Operators[opIdx] + negate := op == "not" + if negate { + opIdx++ + op = n.Operators[opIdx] + } + if op != "&&" && op == p.Operator { + if ret, fn, ok := p.FindSuitableOperatorOverload(nodeLeft.Type(), comparator.Type()); ok { + var newNode ast.Node = &ast.CallNode{ + Callee: &ast.IdentifierNode{Value: fn}, + Arguments: []ast.Node{nodeLeft, comparator}, + } + newNode.SetType(ret) + if negate { + newNode = &ast.UnaryNode{ + Operator: "not", + Node: newNode, + } + n.Operators = append(n.Operators[:opIdx], n.Operators[opIdx+1:]...) + opIdx-- + } + n.Operators[opIdx] = "&&" + if i == 0 { + n.Left = newNode + } else { + n.Comparators[i] = newNode + } + p.applied = true + } + } + nodeLeft = comparator + opIdx++ } - newNode.SetType(ret) - ast.Patch(node, newNode) - p.applied = true } } @@ -92,9 +122,8 @@ func (p *OperatorOverloading) findSuitableOperatorOverloadInFunctions(l, r refle func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) { firstArgType := t.In(firstInIndex) secondArgType := t.In(firstInIndex + 1) - - firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType))) - secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType))) + firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType))) || (l.Kind() == reflect.Interface && firstArgType.AssignableTo(l)) + secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType))) || (r.Kind() == reflect.Interface && secondArgType.AssignableTo(r)) if firstArgumentFit && secondArgumentFit { return t.Out(0), true }