From ff601c8d6e47fba200379fb81e42f29ee61bdc6f Mon Sep 17 00:00:00 2001 From: Andrei Vydrin Date: Sat, 4 Dec 2021 16:04:53 +0700 Subject: [PATCH] fix: rework nilsafe logic with subProperty ops --- ast/node.go | 3 +-- checker/checker.go | 11 ++++------- compiler/compiler.go | 2 -- expr_test.go | 12 ++++++------ parser/parser.go | 13 ++++--------- parser/parser_test.go | 10 +++++++++- vm/opcodes.go | 1 - vm/program.go | 3 --- vm/vm.go | 3 --- 9 files changed, 24 insertions(+), 34 deletions(-) diff --git a/ast/node.go b/ast/node.go index 4b2b5c277..2b9192183 100644 --- a/ast/node.go +++ b/ast/node.go @@ -48,8 +48,7 @@ type NilNode struct { type IdentifierNode struct { base - Value string - NilSafe bool + Value string } type IntegerNode struct { diff --git a/checker/checker.go b/checker/checker.go index 282031a1f..6f55feb64 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -140,10 +140,7 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) reflect.Type { } return interfaceType } - if !node.NilSafe { - return v.error(node, "unknown name %v", node.Value) - } - return nilType + return v.error(node, "unknown name %v", node.Value) } func (v *visitor) IntegerNode(*ast.IntegerNode) reflect.Type { @@ -348,9 +345,9 @@ func (v *visitor) FunctionNode(node *ast.FunctionNode) reflect.Type { fn.NumIn() == inputParamsCount && ((fn.NumOut() == 1 && // Function with one return value fn.Out(0).Kind() == reflect.Interface) || - (fn.NumOut() == 2 && // Function with one return value and an error - fn.Out(0).Kind() == reflect.Interface && - fn.Out(1) == errorType)) { + (fn.NumOut() == 2 && // Function with one return value and an error + fn.Out(0).Kind() == reflect.Interface && + fn.Out(1) == errorType)) { rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods if rest.Kind() == reflect.Slice && rest.Elem().Kind() == reflect.Interface { node.Fast = true diff --git a/compiler/compiler.go b/compiler/compiler.go index 36ac92f23..bd3ef61b2 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -180,8 +180,6 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { v := c.makeConstant(node.Value) if c.mapEnv { c.emit(OpFetchMap, v...) - } else if node.NilSafe { - c.emit(OpFetchNilSafe, v...) } else { c.emit(OpFetch, v...) } diff --git a/expr_test.go b/expr_test.go index 13cb55c8e..2ec063a9f 100644 --- a/expr_test.go +++ b/expr_test.go @@ -946,6 +946,7 @@ func TestExpr_map_default_values(t *testing.T) { func TestExpr_nil_safe(t *testing.T) { env := map[string]interface{}{ + "foo": struct{}{}, "bar": map[string]*string{}, } @@ -961,21 +962,19 @@ func TestExpr_nil_safe(t *testing.T) { func TestExpr_nil_safe_first_ident(t *testing.T) { env := map[string]interface{}{ + "foo": struct{}{}, "bar": map[string]*string{}, } input := `foo?.missing.test == '' && bar['missing'] == nil` - program, err := expr.Compile(input, expr.Env(env)) - require.NoError(t, err) - - output, err := expr.Run(program, env) - require.NoError(t, err) - require.Equal(t, false, output) + _, err := expr.Compile(input, expr.Env(env)) + require.Error(t, err) } func TestExpr_nil_safe_not_strict(t *testing.T) { env := map[string]interface{}{ + "foo": struct{}{}, "bar": map[string]*string{}, } @@ -1011,6 +1010,7 @@ func TestExpr_nil_safe_valid_value(t *testing.T) { func TestExpr_nil_safe_method(t *testing.T) { env := map[string]interface{}{ + "foo": struct{}{}, "bar": map[string]*string{}, } diff --git a/parser/parser.go b/parser/parser.go index 821de9d35..6d34f7936 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -283,7 +283,7 @@ func (p *parser) parsePrimaryExpression() Node { node.SetLocation(token.Location) return node default: - node = p.parseIdentifierExpression(token, p.current) + node = p.parseIdentifierExpression(token) } case Number: @@ -334,7 +334,7 @@ func (p *parser) parsePrimaryExpression() Node { return p.parsePostfixExpression(node) } -func (p *parser) parseIdentifierExpression(token, next Token) Node { +func (p *parser) parseIdentifierExpression(token Token) Node { var node Node if p.current.Is(Bracket, "(") { var arguments []Node @@ -367,11 +367,7 @@ func (p *parser) parseIdentifierExpression(token, next Token) Node { node.SetLocation(token.Location) } } else { - var nilsafe bool - if next.Value == "?." { - nilsafe = true - } - node = &IdentifierNode{Value: token.Value, NilSafe: nilsafe} + node = &IdentifierNode{Value: token.Value} node.SetLocation(token.Location) } return node @@ -464,14 +460,13 @@ end: func (p *parser) parsePostfixExpression(node Node) Node { token := p.current - var nilsafe bool for (token.Is(Operator) || token.Is(Bracket)) && p.err == nil { if token.Value == "." || token.Value == "?." { + var nilsafe bool if token.Value == "?." { nilsafe = true } p.next() - token = p.current p.next() diff --git a/parser/parser_test.go b/parser/parser_test.go index 989829b37..60b2db789 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -103,7 +103,7 @@ func TestParse(t *testing.T) { }, { "foo?.bar", - &ast.PropertyNode{Node: &ast.IdentifierNode{Value: "foo", NilSafe: true}, Property: "bar", NilSafe: true}, + &ast.PropertyNode{Node: &ast.IdentifierNode{Value: "foo"}, Property: "bar", NilSafe: true}, }, { "foo['all']", @@ -238,6 +238,14 @@ func TestParse(t *testing.T) { "[1, 2, 3,]", &ast.ArrayNode{Nodes: []ast.Node{&ast.IntegerNode{Value: 1}, &ast.IntegerNode{Value: 2}, &ast.IntegerNode{Value: 3}}}, }, + { + "a?.b.c", + &ast.PropertyNode{Node: &ast.PropertyNode{Node: &ast.IdentifierNode{Value: "a"}, Property: "b", NilSafe: true}, Property: "c", NilSafe: false}, + }, + { + "a?.b?.c", + &ast.PropertyNode{Node: &ast.PropertyNode{Node: &ast.IdentifierNode{Value: "a"}, Property: "b", NilSafe: true}, Property: "c", NilSafe: true}, + }, } for _, test := range parseTests { actual, err := parser.Parse(test.input) diff --git a/vm/opcodes.go b/vm/opcodes.go index 7f2dd37e9..1bb4db2ac 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -5,7 +5,6 @@ const ( OpPop OpRot OpFetch - OpFetchNilSafe OpFetchMap OpTrue OpFalse diff --git a/vm/program.go b/vm/program.go index 5a41f8af4..d26a90abb 100644 --- a/vm/program.go +++ b/vm/program.go @@ -73,9 +73,6 @@ func (program *Program) Disassemble() string { case OpFetch: constant("OpFetch") - case OpFetchNilSafe: - constant("OpFetchNilSafe") - case OpFetchMap: constant("OpFetchMap") diff --git a/vm/vm.go b/vm/vm.go index 6957dfa64..94762a5b4 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -102,9 +102,6 @@ func (vm *VM) Run(program *Program, env interface{}) (out interface{}, err error case OpFetch: vm.push(fetch(env, vm.constant(), false)) - case OpFetchNilSafe: - vm.push(fetch(env, vm.constant(), true)) - case OpFetchMap: vm.push(env.(map[string]interface{})[vm.constant().(string)])