Skip to content

Commit baae791

Browse files
authored
Add nil safe operator (#173)
1 parent ea62436 commit baae791

File tree

13 files changed

+286
-21
lines changed

13 files changed

+286
-21
lines changed

ast/node.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ type NilNode struct {
4848

4949
type IdentifierNode struct {
5050
base
51-
Value string
51+
Value string
52+
NilSafe bool
5253
}
5354

5455
type IntegerNode struct {
@@ -100,6 +101,7 @@ type PropertyNode struct {
100101
base
101102
Node Node
102103
Property string
104+
NilSafe bool
103105
}
104106

105107
type IndexNode struct {
@@ -120,6 +122,7 @@ type MethodNode struct {
120122
Node Node
121123
Method string
122124
Arguments []Node
125+
NilSafe bool
123126
}
124127

125128
type FunctionNode struct {

checker/checker.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ func (v *visitor) IdentifierNode(node *ast.IdentifierNode) reflect.Type {
136136
}
137137
return interfaceType
138138
}
139-
return v.error(node, "unknown name %v", node.Value)
139+
if !node.NilSafe {
140+
return v.error(node, "unknown name %v", node.Value)
141+
}
142+
return nilType
140143
}
141144

142145
func (v *visitor) IntegerNode(*ast.IntegerNode) reflect.Type {
@@ -276,12 +279,13 @@ func (v *visitor) MatchesNode(node *ast.MatchesNode) reflect.Type {
276279

277280
func (v *visitor) PropertyNode(node *ast.PropertyNode) reflect.Type {
278281
t := v.visit(node.Node)
279-
280282
if t, ok := fieldType(t, node.Property); ok {
281283
return t
282284
}
283-
284-
return v.error(node, "type %v has no field %v", t, node.Property)
285+
if !node.NilSafe {
286+
return v.error(node, "type %v has no field %v", t, node.Property)
287+
}
288+
return nil
285289
}
286290

287291
func (v *visitor) IndexNode(node *ast.IndexNode) reflect.Type {
@@ -361,7 +365,10 @@ func (v *visitor) MethodNode(node *ast.MethodNode) reflect.Type {
361365
return v.checkFunc(fn, method, node, node.Method, node.Arguments)
362366
}
363367
}
364-
return v.error(node, "type %v has no method %v", t, node.Method)
368+
if !node.NilSafe {
369+
return v.error(node, "type %v has no method %v", t, node.Method)
370+
}
371+
return nil
365372
}
366373

367374
// checkFunc checks func arguments and returns "return type" of func or method.

cmd/exe/dot.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ func (v *visitor) Exit(ref *Node) {
8787

8888
case *PropertyNode:
8989
a := v.pop()
90-
v.push(fmt.Sprintf(".%v", node.Property))
90+
if !node.NilSafe {
91+
v.push(fmt.Sprintf(".%v", node.Property))
92+
} else {
93+
v.push(fmt.Sprintf("?.%v", node.Property))
94+
}
9195
v.link(a)
9296

9397
case *IndexNode:
@@ -103,7 +107,11 @@ func (v *visitor) Exit(ref *Node) {
103107
args = append(args, v.pop())
104108
}
105109
a := v.pop()
106-
v.push(fmt.Sprintf(".%v(...)", node.Method))
110+
if !node.NilSafe {
111+
v.push(fmt.Sprintf(".%v(...)", node.Method))
112+
} else {
113+
v.push(fmt.Sprintf("?.%v(...)", node.Method))
114+
}
107115
v.link(a)
108116
for i := len(args) - 1; i >= 0; i-- {
109117
v.link(args[i])

compiler/compiler.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) {
180180
v := c.makeConstant(node.Value)
181181
if c.mapEnv {
182182
c.emit(OpFetchMap, v...)
183+
} else if node.NilSafe {
184+
c.emit(OpFetchNilSafe, v...)
183185
} else {
184186
c.emit(OpFetch, v...)
185187
}
@@ -401,7 +403,11 @@ func (c *compiler) MatchesNode(node *ast.MatchesNode) {
401403

402404
func (c *compiler) PropertyNode(node *ast.PropertyNode) {
403405
c.compile(node.Node)
404-
c.emit(OpProperty, c.makeConstant(node.Property)...)
406+
if !node.NilSafe {
407+
c.emit(OpProperty, c.makeConstant(node.Property)...)
408+
} else {
409+
c.emit(OpPropertyNilSafe, c.makeConstant(node.Property)...)
410+
}
405411
}
406412

407413
func (c *compiler) IndexNode(node *ast.IndexNode) {
@@ -430,7 +436,11 @@ func (c *compiler) MethodNode(node *ast.MethodNode) {
430436
for _, arg := range node.Arguments {
431437
c.compile(arg)
432438
}
433-
c.emit(OpMethod, c.makeConstant(Call{Name: node.Method, Size: len(node.Arguments)})...)
439+
if !node.NilSafe {
440+
c.emit(OpMethod, c.makeConstant(Call{Name: node.Method, Size: len(node.Arguments)})...)
441+
} else {
442+
c.emit(OpMethodNilSafe, c.makeConstant(Call{Name: node.Method, Size: len(node.Arguments)})...)
443+
}
434444
}
435445

436446
func (c *compiler) FunctionNode(node *ast.FunctionNode) {

expr_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,151 @@ func TestExpr_map_default_values(t *testing.T) {
944944
require.Equal(t, true, output)
945945
}
946946

947+
func TestExpr_nil_safe(t *testing.T) {
948+
env := map[string]interface{}{
949+
"bar": map[string]*string{},
950+
}
951+
952+
input := `foo?.missing?.test == '' && bar['missing'] == nil`
953+
954+
program, err := expr.Compile(input, expr.Env(env))
955+
require.NoError(t, err)
956+
957+
output, err := expr.Run(program, env)
958+
require.NoError(t, err)
959+
require.Equal(t, false, output)
960+
}
961+
962+
func TestExpr_nil_safe_first_ident(t *testing.T) {
963+
env := map[string]interface{}{
964+
"bar": map[string]*string{},
965+
}
966+
967+
input := `foo?.missing.test == '' && bar['missing'] == nil`
968+
969+
program, err := expr.Compile(input, expr.Env(env))
970+
require.NoError(t, err)
971+
972+
output, err := expr.Run(program, env)
973+
require.NoError(t, err)
974+
require.Equal(t, false, output)
975+
}
976+
977+
func TestExpr_nil_safe_not_strict(t *testing.T) {
978+
env := map[string]interface{}{
979+
"bar": map[string]*string{},
980+
}
981+
982+
input := `foo?.missing?.test == '' && bar['missing'] == nil`
983+
984+
program, err := expr.Compile(input)
985+
require.NoError(t, err)
986+
987+
output, err := expr.Run(program, env)
988+
require.NoError(t, err)
989+
require.Equal(t, false, output)
990+
}
991+
992+
func TestExpr_nil_safe_valid_value(t *testing.T) {
993+
env := map[string]interface{}{
994+
"foo": map[string]map[string]interface{}{
995+
"missing": {
996+
"test": "hello",
997+
},
998+
},
999+
"bar": map[string]*string{},
1000+
}
1001+
1002+
input := `foo?.missing?.test == 'hello' && bar['missing'] == nil`
1003+
1004+
program, err := expr.Compile(input, expr.Env(env))
1005+
require.NoError(t, err)
1006+
1007+
output, err := expr.Run(program, env)
1008+
require.NoError(t, err)
1009+
require.Equal(t, true, output)
1010+
}
1011+
1012+
func TestExpr_nil_safe_method(t *testing.T) {
1013+
env := map[string]interface{}{
1014+
"bar": map[string]*string{},
1015+
}
1016+
1017+
input := `foo?.missing?.test() == '' && bar['missing'] == nil`
1018+
1019+
program, err := expr.Compile(input, expr.Env(env))
1020+
require.NoError(t, err)
1021+
1022+
output, err := expr.Run(program, env)
1023+
require.NoError(t, err)
1024+
require.Equal(t, false, output)
1025+
}
1026+
1027+
func TestExpr_nil_safe_struct(t *testing.T) {
1028+
type P struct {
1029+
Test string
1030+
}
1031+
type Env struct {
1032+
Foo struct {
1033+
Missing *P
1034+
}
1035+
Bar struct {
1036+
Missing *P
1037+
}
1038+
}
1039+
env := Env{
1040+
Bar: struct {
1041+
Missing *P
1042+
}{
1043+
Missing: nil,
1044+
},
1045+
}
1046+
input := `Foo?.Missing?.Test == '' && Bar.Missing == nil`
1047+
1048+
program, err := expr.Compile(input)
1049+
require.NoError(t, err)
1050+
1051+
output, err := expr.Run(program, env)
1052+
require.NoError(t, err)
1053+
require.Equal(t, false, output)
1054+
}
1055+
1056+
func TestExpr_nil_safe_struct_valid(t *testing.T) {
1057+
type P struct {
1058+
Test string
1059+
}
1060+
type Env struct {
1061+
Foo struct {
1062+
Missing *P
1063+
}
1064+
Bar struct {
1065+
Missing *P
1066+
}
1067+
}
1068+
env := Env{
1069+
Foo: struct {
1070+
Missing *P
1071+
}{
1072+
Missing: &P{
1073+
Test: "hello",
1074+
},
1075+
},
1076+
Bar: struct {
1077+
Missing *P
1078+
}{
1079+
Missing: nil,
1080+
},
1081+
}
1082+
input := `Foo?.Missing?.Test == 'hello' && Bar.Missing == nil`
1083+
1084+
program, err := expr.Compile(input)
1085+
require.NoError(t, err)
1086+
1087+
output, err := expr.Run(program, env)
1088+
require.NoError(t, err)
1089+
require.Equal(t, true, output)
1090+
}
1091+
9471092
func TestExpr_map_default_values_compile_check(t *testing.T) {
9481093
tests := []struct {
9491094
env interface{}

parser/lexer/lexer_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ var lexTests = []lexTest{
6363
{Kind: EOF},
6464
},
6565
},
66+
{
67+
"a and orb().val and foo?.bar",
68+
[]Token{
69+
{Kind: Identifier, Value: "a"},
70+
{Kind: Operator, Value: "and"},
71+
{Kind: Identifier, Value: "orb"},
72+
{Kind: Bracket, Value: "("},
73+
{Kind: Bracket, Value: ")"},
74+
{Kind: Operator, Value: "."},
75+
{Kind: Identifier, Value: "val"},
76+
{Kind: Operator, Value: "and"},
77+
{Kind: Identifier, Value: "foo"},
78+
{Kind: Operator, Value: "?."},
79+
{Kind: Identifier, Value: "bar"},
80+
{Kind: EOF},
81+
},
82+
},
6683
{
6784
`not in not abc not i not(false) not in not in`,
6885
[]Token{

parser/lexer/state.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ func root(l *lexer) stateFn {
2424
case '0' <= r && r <= '9':
2525
l.backup()
2626
return number
27+
case r == '?':
28+
if l.peek() == '.' {
29+
return nilsafe
30+
}
31+
l.emit(Operator)
2732
case strings.ContainsRune("([{", r):
2833
l.emit(Bracket)
2934
case strings.ContainsRune(")]}", r):
@@ -102,6 +107,13 @@ func dot(l *lexer) stateFn {
102107
return root
103108
}
104109

110+
func nilsafe(l *lexer) stateFn {
111+
l.next()
112+
l.accept("?.")
113+
l.emit(Operator)
114+
return root
115+
}
116+
105117
func identifier(l *lexer) stateFn {
106118
loop:
107119
for {

parser/parser.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ func (p *parser) parsePrimaryExpression() Node {
283283
node.SetLocation(token.Location)
284284
return node
285285
default:
286-
node = p.parseIdentifierExpression(token)
286+
node = p.parseIdentifierExpression(token, p.current)
287287
}
288288

289289
case Number:
@@ -334,7 +334,7 @@ func (p *parser) parsePrimaryExpression() Node {
334334
return p.parsePostfixExpression(node)
335335
}
336336

337-
func (p *parser) parseIdentifierExpression(token Token) Node {
337+
func (p *parser) parseIdentifierExpression(token, next Token) Node {
338338
var node Node
339339
if p.current.Is(Bracket, "(") {
340340
var arguments []Node
@@ -367,7 +367,11 @@ func (p *parser) parseIdentifierExpression(token Token) Node {
367367
node.SetLocation(token.Location)
368368
}
369369
} else {
370-
node = &IdentifierNode{Value: token.Value}
370+
var nilsafe bool
371+
if next.Value == "?." {
372+
nilsafe = true
373+
}
374+
node = &IdentifierNode{Value: token.Value, NilSafe: nilsafe}
371375
node.SetLocation(token.Location)
372376
}
373377
return node
@@ -460,8 +464,12 @@ end:
460464

461465
func (p *parser) parsePostfixExpression(node Node) Node {
462466
token := p.current
467+
var nilsafe bool
463468
for (token.Is(Operator) || token.Is(Bracket)) && p.err == nil {
464-
if token.Value == "." {
469+
if token.Value == "." || token.Value == "?." {
470+
if token.Value == "?." {
471+
nilsafe = true
472+
}
465473
p.next()
466474

467475
token = p.current
@@ -479,12 +487,14 @@ func (p *parser) parsePostfixExpression(node Node) Node {
479487
Node: node,
480488
Method: token.Value,
481489
Arguments: arguments,
490+
NilSafe: nilsafe,
482491
}
483492
node.SetLocation(token.Location)
484493
} else {
485494
node = &PropertyNode{
486495
Node: node,
487496
Property: token.Value,
497+
NilSafe: nilsafe,
488498
}
489499
node.SetLocation(token.Location)
490500
}
@@ -537,7 +547,6 @@ func (p *parser) parsePostfixExpression(node Node) Node {
537547
p.expect(Bracket, "]")
538548
}
539549
}
540-
541550
} else {
542551
break
543552
}

0 commit comments

Comments
 (0)