Skip to content

Enable comparison feature like python #664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,11 @@ type PairNode struct {
Key Node // Key of the pair.
Value Node // Value of the pair.
}

// CompareNode represents comparison
type CompareNode struct {
base
Left Node // Left represents the left-hand side of the comparison operation
Operators []string // Operators is a list of comparison operator tokens used in the comparison.
Comparators []Node // Comparators representing the right-hand sides of the comparison operation
}
31 changes: 31 additions & 0 deletions ast/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,34 @@ func (n *PairNode) String() string {
}
return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String())
}

func (n *CompareNode) string(node Node) string {
switch v := node.(type) {
case *BinaryNode, *CompareNode:
return fmt.Sprintf("(%s)", v)
default:
return v.String()
}
}

func (n *CompareNode) String() string {
var builder strings.Builder
builder.WriteString(n.string(n.Left))
opIdx := 0
for i := 0; i < len(n.Comparators); i++ {
if op := n.Operators[opIdx]; op != "&&" {
builder.WriteByte(' ')
builder.WriteString(op)
if op == "not" {
opIdx++
builder.WriteByte(' ')
builder.WriteString(n.Operators[opIdx])
}
builder.WriteByte(' ')
builder.WriteString(n.string(n.Comparators[i]))
}
opIdx++
}

return builder.String()
}
2 changes: 1 addition & 1 deletion ast/print_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestPrint(t *testing.T) {
{`a == b`, `a == b`},
{`a matches b`, `a matches b`},
{`a in b`, `a in b`},
{`a not in b`, `not (a in b)`},
{`a not in b`, `a not in b`},
{`a and b`, `a and b`},
{`a or b`, `a or b`},
{`a or b and c`, `a or (b and c)`},
Expand Down
5 changes: 5 additions & 0 deletions ast/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ func Walk(node *Node, v Visitor) {
case *PairNode:
Walk(&n.Key, v)
Walk(&n.Value, v)
case *CompareNode:
Walk(&n.Left, v)
for i := range n.Comparators {
Walk(&n.Comparators[i], v)
}
default:
panic(fmt.Sprintf("undefined node type (%T)", node))
}
Expand Down
150 changes: 86 additions & 64 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ func (v *checker) visit(node ast.Node) Nature {
nt = v.MapNode(n)
case *ast.PairNode:
nt = v.PairNode(n)
case *ast.CompareNode:
nt = v.CompareNode(n)
default:
panic(fmt.Sprintf("undefined node type (%T)", node))
}
Expand Down Expand Up @@ -282,11 +284,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
r = r.Deref()

switch node.Operator {
case "==", "!=":
if isComparable(l, r) {
return boolNature
}

case "or", "||", "and", "&&":
if isBool(l) && isBool(r) {
return boolNature
Expand All @@ -295,20 +292,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
return boolNature
}

case "<", ">", ">=", "<=":
if isNumber(l) && isNumber(r) {
return boolNature
}
if isString(l) && isString(r) {
return boolNature
}
if isTime(l) && isTime(r) {
return boolNature
}
if or(l, r, isNumber, isString, isTime) {
return boolNature
}

case "-":
if isNumber(l) && isNumber(r) {
return combined(l, r)
Expand Down Expand Up @@ -372,51 +355,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) Nature {
return unknown
}

case "in":
if (isString(l) || isUnknown(l)) && isStruct(r) {
return boolNature
}
if isMap(r) {
if !isUnknown(l) && !l.AssignableTo(r.Key()) {
return v.error(node, "cannot use %v as type %v in map key", l, r.Key())
}
return boolNature
}
if isArray(r) {
if !isComparable(l, r.Elem()) {
return v.error(node, "cannot use %v as type %v in array", l, r.Elem())
}
return boolNature
}
if isUnknown(l) && anyOf(r, isString, isArray, isMap) {
return boolNature
}
if isUnknown(r) {
return boolNature
}

case "matches":
if s, ok := node.Right.(*ast.StringNode); ok {
_, err := regexp.Compile(s.Value)
if err != nil {
return v.error(node, err.Error())
}
}
if isString(l) && isString(r) {
return boolNature
}
if or(l, r, isString) {
return boolNature
}

case "contains", "startsWith", "endsWith":
if isString(l) && isString(r) {
return boolNature
}
if or(l, r, isString) {
return boolNature
}

case "..":
if isInteger(l) && isInteger(r) {
return Nature{
Expand Down Expand Up @@ -1230,3 +1168,87 @@ func (v *checker) PairNode(node *ast.PairNode) Nature {
v.visit(node.Value)
return nilNature
}

func (v *checker) CompareNode(node *ast.CompareNode) Nature {
nodeLeft := node.Left
opIdx := 0
operatorOverride := false
for i, comparator := range node.Comparators {
op := node.Operators[opIdx]
if negate := op == "not"; negate {
opIdx++
op = node.Operators[opIdx]
}
if op == "&&" {
if !operatorOverride {
operatorOverride = true
}
} else if err := v.compareNode(op, nodeLeft, comparator, i); err != nil {
return v.error(comparator, err.Error())
}
opIdx++
nodeLeft = comparator
}
if operatorOverride {
return unknown
}
return boolNature
}

func (v *checker) compareNode(op string, nodeLeft, nodeRight ast.Node, index int) error {
l := v.visit(nodeLeft)
r := v.visit(nodeRight)

l = l.Deref()
r = r.Deref()
switch op {
case "==", "!=":
if (isBool(r) && index > 0) || isComparable(l, r) {
return nil
}
case "<", ">", ">=", "<=":
if isNumber(l) && isNumber(r) ||
isString(l) && isString(r) ||
isTime(l) && isTime(r) ||
or(l, r, isNumber, isString, isTime) {
return nil
}
case "in":
if (isString(l) || isUnknown(l)) && isStruct(r) {
return nil
}
if isMap(r) {
if !isUnknown(l) && !l.AssignableTo(r.Key()) {
return fmt.Errorf("cannot use %v as type %v in map key", l, r.Key())
}
return nil
}
if isArray(r) {
if !isComparable(l, r.Elem()) {
return fmt.Errorf("cannot use %v as type %v in array", l, r.Elem())
}
return nil
}
if (isUnknown(l) && anyOf(r, isString, isArray, isMap)) || isUnknown(r) {
return nil
}

case "matches":
if s, ok := nodeRight.(*ast.StringNode); ok {
if _, err := regexp.Compile(s.Value); err != nil {
return err
}
}
if (isString(l) && isString(r)) || or(l, r, isString) {
return nil
}
case "contains", "startsWith", "endsWith":
if isString(l) && isString(r) ||
or(l, r, isString) {
return nil
}
default:
return fmt.Errorf("unknown operator (%v)", op)
}
return fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, op, l, r)
}
Loading
Loading