diff --git a/ast/node.go b/ast/node.go index f9d27fc6a..f70c49cc7 100644 --- a/ast/node.go +++ b/ast/node.go @@ -58,21 +58,7 @@ type NilNode struct { // IdentifierNode represents an identifier. type IdentifierNode struct { base - Value string // Name of the identifier. Like "foo" in "foo.bar". - FieldIndex []int // Internal. Index of the field in the list of fields. - Method bool // Internal. If true then the identifier is a method call. - MethodIndex int // Internal. Index of the method in the list of methods. -} - -// SetFieldIndex sets the field index of the identifier. -func (n *IdentifierNode) SetFieldIndex(field []int) { - n.FieldIndex = field -} - -// SetMethodIndex sets the method index of the identifier. -func (n *IdentifierNode) SetMethodIndex(methodIndex int) { - n.Method = true - n.MethodIndex = methodIndex + Value string // Name of the identifier. Like "foo" in "foo.bar". } // IntegerNode represents an integer. @@ -146,26 +132,9 @@ type ChainNode struct { // array[0] type MemberNode struct { base - Node Node // Node of the member access. Like "foo" in "foo.bar". - Property Node // Property of the member access. For property access it is a StringNode. - Optional bool // If true then the member access is optional. Like "foo?.bar". - Name string // Internal. Name of the filed or method. Used for error reporting. - FieldIndex []int // Internal. Index sequence of fields. Generated by type checker. - - // TODO: Combine Method and MethodIndex into a single MethodIndex field of &int type. - Method bool // Internal. If true then the member access is a method call. - MethodIndex int // Internal. Index of the method in the list of methods. Generated by type checker. -} - -// SetFieldIndex sets the field index of the member access. -func (n *MemberNode) SetFieldIndex(field []int) { - n.FieldIndex = field -} - -// SetMethodIndex sets the method index of the member access. -func (n *MemberNode) SetMethodIndex(methodIndex int) { - n.Method = true - n.MethodIndex = methodIndex + Node Node // Node of the member access. Like "foo" in "foo.bar". + Property Node // Property of the member access. For property access it is a StringNode. + Optional bool // If true then the member access is optional. Like "foo?.bar". } // SliceNode represents access to a slice of an array. diff --git a/checker/checker.go b/checker/checker.go index c214acf3b..38490bacf 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -166,23 +166,13 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) return v.env(node, node.Value, true) } -type NodeWithIndexes interface { - ast.Node - SetFieldIndex(field []int) - SetMethodIndex(methodIndex int) -} - // env method returns type of environment variable. env only lookups for // environment variables, no builtins, no custom functions. -func (v *checker) env(node NodeWithIndexes, name string, strict bool) (reflect.Type, info) { +func (v *checker) env(node ast.Node, name string, strict bool) (reflect.Type, info) { if t, ok := v.config.Types[name]; ok { if t.Ambiguous { return v.error(node, "ambiguous identifier %v", name) } - node.SetFieldIndex(t.FieldIndex) - if t.Method { - node.SetMethodIndex(t.MethodIndex) - } return t.Type, info{method: t.Method} } if v.config.Strict && strict { @@ -477,8 +467,6 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { // the same interface. return m.Type, info{} } else { - node.SetMethodIndex(m.Index) - node.Name = name.Value return m.Type, info{method: true} } } @@ -508,8 +496,6 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value if field, ok := fetchField(base, propertyName); ok { - node.FieldIndex = field.Index - node.Name = propertyName return field.Type, info{} } if len(v.parents) > 1 { diff --git a/checker/info.go b/checker/info.go new file mode 100644 index 000000000..3245253a3 --- /dev/null +++ b/checker/info.go @@ -0,0 +1,50 @@ +package checker + +import ( + "reflect" + + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/conf" +) + +func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { + switch n := node.(type) { + case *ast.IdentifierNode: + if t, ok := types[n.Value]; ok && len(t.FieldIndex) > 0 { + return true, t.FieldIndex, n.Value + } + case *ast.MemberNode: + base := n.Node.Type() + if kind(base) == reflect.Ptr { + base = base.Elem() + } + if kind(base) == reflect.Struct { + if prop, ok := n.Property.(*ast.StringNode); ok { + name := prop.Value + if field, ok := fetchField(base, name); ok { + return true, field.Index, name + } + } + } + } + return false, nil, "" +} + +func MethodIndex(types conf.TypesTable, node ast.Node) (bool, int, string) { + switch n := node.(type) { + case *ast.IdentifierNode: + if t, ok := types[n.Value]; ok { + return t.Method, t.MethodIndex, n.Value + } + case *ast.MemberNode: + if name, ok := n.Property.(*ast.StringNode); ok { + base := n.Node.Type() + if base != nil && base.Kind() != reflect.Interface { + if m, ok := base.MethodByName(name.Value); ok { + return true, m.Index, name.Value + } + } + } + } + return false, 0, "" +} diff --git a/compiler/compiler.go b/compiler/compiler.go index 4831a3103..2b03e5afa 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -6,6 +6,7 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/checker" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/parser" @@ -34,6 +35,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro if config != nil { c.mapEnv = config.MapEnv c.cast = config.Expect + c.types = config.Types } c.compile(tree.Node) @@ -75,6 +77,7 @@ type compiler struct { nodes []ast.Node chains [][]int arguments []int + types conf.TypesTable } type scope struct { @@ -254,15 +257,15 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { } if c.mapEnv { c.emit(OpLoadFast, c.addConstant(node.Value)) - } else if len(node.FieldIndex) > 0 { + } else if ok, index, name := checker.FieldIndex(c.types, node); ok { c.emit(OpLoadField, c.addConstant(&runtime.Field{ - Index: node.FieldIndex, - Path: []string{node.Value}, + Index: index, + Path: []string{name}, })) - } else if node.Method { + } else if ok, index, name := checker.MethodIndex(c.types, node); ok { c.emit(OpLoadMethod, c.addConstant(&runtime.Method{ - Name: node.Value, - Index: node.MethodIndex, + Name: name, + Index: index, })) } else { c.emit(OpLoadConst, c.addConstant(node.Value)) @@ -559,36 +562,43 @@ func (c *compiler) ChainNode(node *ast.ChainNode) { } func (c *compiler) MemberNode(node *ast.MemberNode) { - if node.Method { + if ok, index, name := checker.MethodIndex(c.types, node); ok { c.compile(node.Node) c.emit(OpMethod, c.addConstant(&runtime.Method{ - Name: node.Name, - Index: node.MethodIndex, + Name: name, + Index: index, })) return } op := OpFetch - index := node.FieldIndex - path := []string{node.Name} base := node.Node - if len(node.FieldIndex) > 0 { + + ok, index, nodeName := checker.FieldIndex(c.types, node) + path := []string{nodeName} + + if ok { op = OpFetchField for !node.Optional { - ident, ok := base.(*ast.IdentifierNode) - if ok && len(ident.FieldIndex) > 0 { - index = append(ident.FieldIndex, index...) - path = append([]string{ident.Value}, path...) - c.emitLocation(ident.Location(), OpLoadField, c.addConstant( - &runtime.Field{Index: index, Path: path}, - )) - return + if ident, isIdent := base.(*ast.IdentifierNode); isIdent { + if ok, identIndex, name := checker.FieldIndex(c.types, ident); ok { + index = append(identIndex, index...) + path = append([]string{name}, path...) + c.emitLocation(ident.Location(), OpLoadField, c.addConstant( + &runtime.Field{Index: index, Path: path}, + )) + return + } } - member, ok := base.(*ast.MemberNode) - if ok && len(member.FieldIndex) > 0 { - index = append(member.FieldIndex, index...) - path = append([]string{member.Name}, path...) - node = member - base = member.Node + + if member, isMember := base.(*ast.MemberNode); isMember { + if ok, memberIndex, name := checker.FieldIndex(c.types, member); ok { + index = append(memberIndex, index...) + path = append([]string{name}, path...) + node = member + base = member.Node + } else { + break + } } else { break } diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index c3b460aa0..8068f5e69 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -24,6 +24,10 @@ type B struct { } } +func (B) FuncInB() int { + return 0 +} + type Env struct { A struct { _ byte @@ -33,12 +37,20 @@ type Env struct { } } +// AFunc is a method what goes before Func in the alphabet. +func (e Env) AFunc() int { + return 0 +} + +func (e Env) Func() B { + return B{} +} + func TestCompile(t *testing.T) { - type test struct { - input string - program vm.Program - } - var tests = []test{ + var tests = []struct { + code string + want vm.Program + }{ { `65535`, vm.Program{ @@ -271,13 +283,53 @@ func TestCompile(t *testing.T) { Arguments: []int{0, 0, 1, 0}, }, }, + { + `Func()`, + vm.Program{ + Constants: []any{ + &runtime.Method{ + Index: 1, + Name: "Func", + }, + }, + Bytecode: []vm.Opcode{ + vm.OpLoadMethod, + vm.OpCall, + }, + Arguments: []int{0, 0}, + }, + }, + { + `Func().FuncInB()`, + vm.Program{ + Constants: []any{ + &runtime.Method{ + Index: 1, + Name: "Func", + }, + &runtime.Method{ + Index: 0, + Name: "FuncInB", + }, + }, + Bytecode: []vm.Opcode{ + vm.OpLoadMethod, + vm.OpCall, + vm.OpMethod, + vm.OpCallTyped, + }, + Arguments: []int{0, 0, 1, 10}, + }, + }, } for _, test := range tests { - program, err := expr.Compile(test.input, expr.Env(Env{}), expr.Optimize(false)) - require.NoError(t, err, test.input) + t.Run(test.code, func(t *testing.T) { + program, err := expr.Compile(test.code, expr.Env(Env{}), expr.Optimize(false)) + require.NoError(t, err) - assert.Equal(t, test.program.Disassemble(), program.Disassemble(), test.input) + assert.Equal(t, test.want.Disassemble(), program.Disassemble()) + }) } } diff --git a/test/interface_method/interface_method_test.go b/test/interface_method/interface_method_test.go index 23d9862d6..51057f7a8 100644 --- a/test/interface_method/interface_method_test.go +++ b/test/interface_method/interface_method_test.go @@ -3,9 +3,10 @@ package interface_method_test import ( "testing" - "github.com/expr-lang/expr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/expr-lang/expr" ) type Bar interface { @@ -39,7 +40,6 @@ func TestInterfaceMethod(t *testing.T) { "var": FooImpl{}, } p, err := expr.Compile(`var.Foo().Bar()`, expr.Env(env)) - assert.NoError(t, err) out, err := expr.Run(p, env) diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index fc5c3b525..af50a24eb 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -4,9 +4,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/expr-lang/expr" "github.com/expr-lang/expr/test/mock" - "github.com/stretchr/testify/require" ) func TestOperator_struct(t *testing.T) {