diff --git a/ast/node.go b/ast/node.go index f70c49cc7..68c9e7bef 100644 --- a/ast/node.go +++ b/ast/node.go @@ -2,7 +2,6 @@ package ast import ( "reflect" - "regexp" "github.com/expr-lang/expr/file" ) @@ -104,10 +103,9 @@ type UnaryNode struct { // BinaryNode represents a binary operator. type BinaryNode struct { base - Operator string // Operator of the binary operator. Like "+" in "foo + bar" or "matches" in "foo matches bar". - Left Node // Left node of the binary operator. - Right Node // Right node of the binary operator. - Regexp *regexp.Regexp // Internal. Regexp of the "matches" operator. Like "f.+". + Operator string // Operator of the binary operator. Like "+" in "foo + bar" or "matches" in "foo matches bar". + Left Node // Left node of the binary operator. + Right Node // Right node of the binary operator. } // ChainNode represents an optional chaining group. @@ -151,11 +149,8 @@ type SliceNode struct { // CallNode represents a function or a method call. type CallNode struct { base - Callee Node // Node of the call. Like "foo" in "foo()". - Arguments []Node // Arguments of the call. - Typed int // Internal. Used to indicate compiler what type is one of vm.FuncTypes. - Fast bool // Internal. Used to indicate compiler what this call is a fast call. - Func *Function // Internal. Used to pass function information from type checker to compiler. + Callee Node // Node of the call. Like "foo" in "foo()". + Arguments []Node // Arguments of the call. } // BuiltinNode represents a builtin function call. @@ -163,8 +158,8 @@ type BuiltinNode struct { base Name string // Name of the builtin function. Like "len" in "len(foo)". Arguments []Node // Arguments of the builtin function. - Throws bool // Internal. If true then accessing a field or array index can throw an error. Used by optimizer. - Map Node // Internal. Used by optimizer to fold filter() and map() builtins. + Throws bool // If true then accessing a field or array index can throw an error. Used by optimizer. + Map Node // Used by optimizer to fold filter() and map() builtins. } // ClosureNode represents a predicate. diff --git a/checker/checker.go b/checker/checker.go index 38490bacf..5afe18283 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -10,7 +10,6 @@ import ( "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/parser" - "github.com/expr-lang/expr/vm" ) func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { @@ -374,11 +373,10 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { case "matches": if s, ok := node.Right.(*ast.StringNode); ok { - r, err := regexp.Compile(s.Value) + _, err := regexp.Compile(s.Value) if err != nil { return v.error(node, err.Error()) } - node.Regexp = r } if isString(l) && isString(r) { return boolType, info{} @@ -549,7 +547,6 @@ func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { fn, fnInfo := v.visit(node.Callee) if fnInfo.fn != nil { - node.Func = fnInfo.fn return v.checkFunction(fnInfo.fn, node, node.Arguments) } @@ -571,23 +568,6 @@ func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { case reflect.Interface: return anyType, info{} case reflect.Func: - inputParamsCount := 1 // for functions - if fnInfo.method { - inputParamsCount = 2 // for methods - } - // TODO: Deprecate OpCallFast and move fn(...any) any to TypedFunc list. - // To do this we need add support for variadic arguments in OpCallTyped. - if !isAny(fn) && - fn.IsVariadic() && - fn.NumIn() == inputParamsCount && - fn.NumOut() == 1 && - fn.Out(0).Kind() == reflect.Interface { - rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods - if kind(rest) == reflect.Slice && rest.Elem().Kind() == reflect.Interface { - node.Fast = true - } - } - outType, err := v.checkArguments(fnName, fn, fnInfo.method, node.Arguments, node) if err != nil { if v.err == nil { @@ -595,9 +575,6 @@ func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { } return anyType, info{} } - - v.findTypedFunc(node, fn, fnInfo.method) - return outType, info{} } return v.error(node, "%v is not callable", fn) @@ -883,7 +860,13 @@ func (v *checker) checkFunction(f *ast.Function, node ast.Node, arguments []ast. return v.error(node, "no matching overload for %v", f.Name) } -func (v *checker) checkArguments(name string, fn reflect.Type, method bool, arguments []ast.Node, node ast.Node) (reflect.Type, *file.Error) { +func (v *checker) checkArguments( + name string, + fn reflect.Type, + method bool, + arguments []ast.Node, + node ast.Node, +) (reflect.Type, *file.Error) { if isAny(fn) { return anyType, nil } @@ -1122,44 +1105,3 @@ func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) { v.visit(node.Value) return nilType, info{} } - -func (v *checker) findTypedFunc(node *ast.CallNode, fn reflect.Type, method bool) { - // OnCallTyped doesn't work for functions with variadic arguments, - // and doesn't work named function, like `type MyFunc func() int`. - // In PkgPath() is an empty string, it's unnamed function. - if !fn.IsVariadic() && fn.PkgPath() == "" { - fnNumIn := fn.NumIn() - fnInOffset := 0 - if method { - fnNumIn-- - fnInOffset = 1 - } - funcTypes: - for i := range vm.FuncTypes { - if i == 0 { - continue - } - typed := reflect.ValueOf(vm.FuncTypes[i]).Elem().Type() - if typed.Kind() != reflect.Func { - continue - } - if typed.NumOut() != fn.NumOut() { - continue - } - for j := 0; j < typed.NumOut(); j++ { - if typed.Out(j) != fn.Out(j) { - continue funcTypes - } - } - if typed.NumIn() != fnNumIn { - continue - } - for j := 0; j < typed.NumIn(); j++ { - if typed.In(j) != fn.In(j+fnInOffset) { - continue funcTypes - } - } - node.Typed = i - } - } -} diff --git a/checker/checker_test.go b/checker/checker_test.go index 42301261d..d03e3a8ee 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -554,6 +554,11 @@ unknown pointer #unknown (1:11) cannot use int as type string in array (1:4) | 42 in ["a", "b", "c"] | ...^ + +"foo" matches "[+" +error parsing regexp: missing closing ]: ` + "`[+`" + ` (1:7) + | "foo" matches "[+" + | ......^ ` func TestCheck_error(t *testing.T) { @@ -777,49 +782,6 @@ func TestCheck_TypeWeights(t *testing.T) { } } -func TestCheck_CallFastTyped(t *testing.T) { - env := map[string]any{ - "fn": func([]any, string) string { - return "foo" - }, - } - - tree, err := parser.Parse("fn([1, 2], 'bar')") - require.NoError(t, err) - - _, err = checker.Check(tree, conf.New(env)) - require.NoError(t, err) - - require.False(t, tree.Node.(*ast.CallNode).Fast) - require.Equal(t, 22, tree.Node.(*ast.CallNode).Typed) -} - -func TestCheck_CallFastTyped_Method(t *testing.T) { - env := mock.Env{} - - tree, err := parser.Parse("FuncTyped('bar')") - require.NoError(t, err) - - _, err = checker.Check(tree, conf.New(env)) - require.NoError(t, err) - - require.False(t, tree.Node.(*ast.CallNode).Fast) - require.Equal(t, 42, tree.Node.(*ast.CallNode).Typed) -} - -func TestCheck_CallTyped_excludes_named_functions(t *testing.T) { - env := mock.Env{} - - tree, err := parser.Parse("FuncNamed('bar')") - require.NoError(t, err) - - _, err = checker.Check(tree, conf.New(env)) - require.NoError(t, err) - - require.False(t, tree.Node.(*ast.CallNode).Fast) - require.Equal(t, 0, tree.Node.(*ast.CallNode).Typed) -} - func TestCheck_works_with_nil_types(t *testing.T) { env := map[string]any{ "null": nil, @@ -908,8 +870,7 @@ func TestCheck_Function_types_are_checked(t *testing.T) { _, err = checker.Check(tree, config) require.NoError(t, err) - require.NotNil(t, tree.Node.(*ast.CallNode).Func) - require.Equal(t, "add", tree.Node.(*ast.CallNode).Func.Name) + require.Equal(t, reflect.Int, tree.Node.Type().Kind()) }) } @@ -943,8 +904,7 @@ func TestCheck_Function_without_types(t *testing.T) { _, err = checker.Check(tree, config) require.NoError(t, err) - require.NotNil(t, tree.Node.(*ast.CallNode).Func) - require.Equal(t, "add", tree.Node.(*ast.CallNode).Func.Name) + require.Equal(t, reflect.Interface, tree.Node.Type().Kind()) } func TestCheck_dont_panic_on_nil_arguments_for_builtins(t *testing.T) { diff --git a/checker/info.go b/checker/info.go index 3245253a3..112bfab31 100644 --- a/checker/info.go +++ b/checker/info.go @@ -5,6 +5,7 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/vm" ) func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { @@ -48,3 +49,80 @@ func MethodIndex(types conf.TypesTable, node ast.Node) (bool, int, string) { } return false, 0, "" } + +func TypedFuncIndex(fn reflect.Type, method bool) (int, bool) { + if fn == nil { + return 0, false + } + if fn.Kind() != reflect.Func { + return 0, false + } + // OnCallTyped doesn't work for functions with variadic arguments. + if fn.IsVariadic() { + return 0, false + } + // OnCallTyped doesn't work named function, like `type MyFunc func() int`. + if fn.PkgPath() != "" { // If PkgPath() is not empty, it means that function is named. + return 0, false + } + + fnNumIn := fn.NumIn() + fnInOffset := 0 + if method { + fnNumIn-- + fnInOffset = 1 + } + +funcTypes: + for i := range vm.FuncTypes { + if i == 0 { + continue + } + typed := reflect.ValueOf(vm.FuncTypes[i]).Elem().Type() + if typed.Kind() != reflect.Func { + continue + } + if typed.NumOut() != fn.NumOut() { + continue + } + for j := 0; j < typed.NumOut(); j++ { + if typed.Out(j) != fn.Out(j) { + continue funcTypes + } + } + if typed.NumIn() != fnNumIn { + continue + } + for j := 0; j < typed.NumIn(); j++ { + if typed.In(j) != fn.In(j+fnInOffset) { + continue funcTypes + } + } + return i, true + } + return 0, false +} + +func IsFastFunc(fn reflect.Type, method bool) bool { + if fn == nil { + return false + } + if fn.Kind() != reflect.Func { + return false + } + numIn := 1 + if method { + numIn = 2 + } + if !isAny(fn) && + fn.IsVariadic() && + fn.NumIn() == numIn && + fn.NumOut() == 1 && + fn.Out(0).Kind() == reflect.Interface { + rest := fn.In(fn.NumIn() - 1) // function has only one param for functions and two for methods + if kind(rest) == reflect.Slice && rest.Elem().Kind() == reflect.Interface { + return true + } + } + return false +} diff --git a/checker/info_test.go b/checker/info_test.go new file mode 100644 index 000000000..c91a55ad6 --- /dev/null +++ b/checker/info_test.go @@ -0,0 +1,27 @@ +package checker_test + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/expr-lang/expr/checker" + "github.com/expr-lang/expr/test/mock" +) + +func TestTypedFuncIndex(t *testing.T) { + fn := func([]any, string) string { + return "foo" + } + index, ok := checker.TypedFuncIndex(reflect.TypeOf(fn), false) + require.True(t, ok) + require.Equal(t, 22, index) +} + +func TestTypedFuncIndex_excludes_named_functions(t *testing.T) { + var fn mock.MyFunc + + _, ok := checker.TypedFuncIndex(reflect.TypeOf(fn), false) + require.False(t, ok) +} diff --git a/compiler/compiler.go b/compiler/compiler.go index 2b03e5afa..50c4816b2 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -3,6 +3,7 @@ package compiler import ( "fmt" "reflect" + "regexp" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" @@ -26,27 +27,24 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro }() c := &compiler{ + config: config, locations: make([]file.Location, 0), constantsIndex: make(map[any]int), functionsIndex: make(map[string]int), debugInfo: make(map[string]string), } - if config != nil { - c.mapEnv = config.MapEnv - c.cast = config.Expect - c.types = config.Types - } - c.compile(tree.Node) - switch c.cast { - case reflect.Int: - c.emit(OpCast, 0) - case reflect.Int64: - c.emit(OpCast, 1) - case reflect.Float64: - c.emit(OpCast, 2) + if c.config != nil { + switch c.config.Expect { + case reflect.Int: + c.emit(OpCast, 0) + case reflect.Int64: + c.emit(OpCast, 1) + case reflect.Float64: + c.emit(OpCast, 2) + } } program = NewProgram( @@ -63,6 +61,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro } type compiler struct { + config *conf.Config locations []file.Location bytecode []Opcode variables []any @@ -72,12 +71,9 @@ type compiler struct { functions []Function functionsIndex map[string]int debugInfo map[string]string - mapEnv bool - cast reflect.Kind nodes []ast.Node chains [][]int arguments []int - types conf.TypesTable } type scope struct { @@ -255,14 +251,22 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { c.emit(OpLoadEnv) return } - if c.mapEnv { + + var mapEnv bool + var types conf.TypesTable + if c.config != nil { + mapEnv = c.config.MapEnv + types = c.config.Types + } + + if mapEnv { c.emit(OpLoadFast, c.addConstant(node.Value)) - } else if ok, index, name := checker.FieldIndex(c.types, node); ok { + } else if ok, index, name := checker.FieldIndex(types, node); ok { c.emit(OpLoadField, c.addConstant(&runtime.Field{ Index: index, Path: []string{name}, })) - } else if ok, index, name := checker.MethodIndex(c.types, node); ok { + } else if ok, index, name := checker.MethodIndex(types, node); ok { c.emit(OpLoadMethod, c.addConstant(&runtime.Method{ Name: name, Index: index, @@ -485,10 +489,14 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { c.emit(OpIn) case "matches": - if node.Regexp != nil { + if str, ok := node.Right.(*ast.StringNode); ok { + re, err := regexp.Compile(str.Value) + if err != nil { + panic(err) + } c.compile(node.Left) c.derefInNeeded(node.Left) - c.emit(OpMatchesConst, c.addConstant(node.Regexp)) + c.emit(OpMatchesConst, c.addConstant(re)) } else { c.compile(node.Left) c.derefInNeeded(node.Left) @@ -562,7 +570,12 @@ func (c *compiler) ChainNode(node *ast.ChainNode) { } func (c *compiler) MemberNode(node *ast.MemberNode) { - if ok, index, name := checker.MethodIndex(c.types, node); ok { + var types conf.TypesTable + if c.config != nil { + types = c.config.Types + } + + if ok, index, name := checker.MethodIndex(types, node); ok { c.compile(node.Node) c.emit(OpMethod, c.addConstant(&runtime.Method{ Name: name, @@ -573,14 +586,14 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { op := OpFetch base := node.Node - ok, index, nodeName := checker.FieldIndex(c.types, node) + ok, index, nodeName := checker.FieldIndex(types, node) path := []string{nodeName} if ok { op = OpFetchField for !node.Optional { if ident, isIdent := base.(*ast.IdentifierNode); isIdent { - if ok, identIndex, name := checker.FieldIndex(c.types, ident); ok { + if ok, identIndex, name := checker.FieldIndex(types, ident); ok { index = append(identIndex, index...) path = append([]string{name}, path...) c.emitLocation(ident.Location(), OpLoadField, c.addConstant( @@ -591,7 +604,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { } if member, isMember := base.(*ast.MemberNode); isMember { - if ok, memberIndex, name := checker.FieldIndex(c.types, member); ok { + if ok, memberIndex, name := checker.FieldIndex(types, member); ok { index = append(memberIndex, index...) path = append([]string{name}, path...) node = member @@ -640,15 +653,21 @@ func (c *compiler) CallNode(node *ast.CallNode) { for _, arg := range node.Arguments { c.compile(arg) } - if node.Func != nil { - c.emitFunction(node.Func, len(node.Arguments)) - return + if ident, ok := node.Callee.(*ast.IdentifierNode); ok { + if c.config != nil { + if fn, ok := c.config.Functions[ident.Value]; ok { + c.emitFunction(fn, len(node.Arguments)) + return + } + } } c.compile(node.Callee) - if node.Typed > 0 { - c.emit(OpCallTyped, node.Typed) + + isMethod, _, _ := checker.MethodIndex(c.config.Types, node.Callee) + if index, ok := checker.TypedFuncIndex(node.Callee.Type(), isMethod); ok { + c.emit(OpCallTyped, index) return - } else if node.Fast { + } else if checker.IsFastFunc(node.Callee.Type(), isMethod) { c.emit(OpCallFast, len(node.Arguments)) } else { c.emit(OpCall, len(node.Arguments)) diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 8068f5e69..ed11a9dd7 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/expr-lang/expr" + "github.com/expr-lang/expr/test/mock" "github.com/expr-lang/expr/test/playground" "github.com/expr-lang/expr/vm" "github.com/expr-lang/expr/vm/runtime" @@ -347,3 +348,39 @@ func TestCompile_panic(t *testing.T) { }) } } + +func TestCompile_FuncTypes(t *testing.T) { + env := map[string]any{ + "fn": func([]any, string) string { + return "foo" + }, + } + program, err := expr.Compile("fn([1, 2], 'bar')", expr.Env(env)) + require.NoError(t, err) + require.Equal(t, vm.OpCallTyped, program.Bytecode[3]) + require.Equal(t, 22, program.Arguments[3]) +} + +func TestCompile_FuncTypes_with_Method(t *testing.T) { + env := mock.Env{} + program, err := expr.Compile("FuncTyped('bar')", expr.Env(env)) + require.NoError(t, err) + require.Equal(t, vm.OpCallTyped, program.Bytecode[2]) + require.Equal(t, 42, program.Arguments[2]) +} + +func TestCompile_FuncTypes_excludes_named_functions(t *testing.T) { + env := mock.Env{} + program, err := expr.Compile("FuncNamed('bar')", expr.Env(env)) + require.NoError(t, err) + require.Equal(t, vm.OpCall, program.Bytecode[2]) + require.Equal(t, 1, program.Arguments[2]) +} + +func TestCompile_OpCallFast(t *testing.T) { + env := mock.Env{} + program, err := expr.Compile("Fast(3, 2, 1)", expr.Env(env)) + require.NoError(t, err) + require.Equal(t, vm.OpCallFast, program.Bytecode[4]) + require.Equal(t, 3, program.Arguments[4]) +} diff --git a/parser/parser_test.go b/parser/parser_test.go index ba8e29acc..453fe91ac 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -222,14 +222,12 @@ world`}, }, }, Arguments: []Node{}, - Fast: false, }, Property: &StringNode{ Value: "c", }, }, Arguments: []Node{}, - Fast: false, }, Property: &StringNode{ Value: "d", diff --git a/vm/vm_test.go b/vm/vm_test.go index 70a0ae057..0cbbb8998 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/checker" "github.com/expr-lang/expr/compiler" "github.com/expr-lang/expr/conf" @@ -145,10 +144,6 @@ func (ErrorEnv) WillError(param string) (bool, error) { return true, nil } -func (ErrorEnv) FastError(...any) any { - return true -} - func (InnerEnv) WillError(param string) (bool, error) { if param == "yes" { return false, errors.New("inner error") @@ -202,27 +197,6 @@ func TestRun_FastMethods(t *testing.T) { require.Equal(t, "hello world", out) } -func TestRun_FastMethodWithError(t *testing.T) { - input := `FastError()` - - tree, err := parser.Parse(input) - require.NoError(t, err) - - env := ErrorEnv{} - funcConf := conf.New(env) - _, err = checker.Check(tree, funcConf) - require.NoError(t, err) - require.True(t, tree.Node.(*ast.CallNode).Fast, "method must be fast") - - program, err := compiler.Compile(tree, funcConf) - require.NoError(t, err) - - out, err := vm.Run(program, env) - require.NoError(t, err) - - require.Equal(t, true, out) -} - func TestRun_InnerMethodWithError(t *testing.T) { input := `InnerEnv.WillError("yes")`