Skip to content

Commit 29dff34

Browse files
committed
Make all builtins working with deref
Fixes #730
1 parent 8f0751d commit 29dff34

File tree

5 files changed

+74
-42
lines changed

5 files changed

+74
-42
lines changed

builtin/builtin.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ var Builtins = []*Function{
164164
if len(args) != 1 {
165165
return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
166166
}
167-
switch kind(deref.Type(args[0])) {
167+
switch kind(args[0]) {
168168
case reflect.Interface:
169169
return integerType, nil
170170
case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
@@ -642,18 +642,21 @@ var Builtins = []*Function{
642642
if len(args) != 2 {
643643
return nil, fmt.Errorf("invalid number of arguments (expected 2, got %d)", len(args))
644644
}
645-
v := reflect.ValueOf(args[0])
645+
v := deref.ValueOf(args[0])
646646
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
647647
return nil, fmt.Errorf("cannot take from %s", v.Kind())
648648
}
649-
n := reflect.ValueOf(args[1])
649+
n := deref.ValueOf(args[1])
650650
if !n.CanInt() {
651651
return nil, fmt.Errorf("cannot take %s elements", n.Kind())
652652
}
653+
to := 0
653654
if n.Int() > int64(v.Len()) {
654-
return args[0], nil
655+
to = v.Len()
656+
} else {
657+
to = int(n.Int())
655658
}
656-
return v.Slice(0, int(n.Int())).Interface(), nil
659+
return v.Slice(0, to).Interface(), nil
657660
},
658661
Validate: func(args []reflect.Type) (reflect.Type, error) {
659662
if len(args) != 2 {
@@ -678,7 +681,7 @@ var Builtins = []*Function{
678681
if len(args) != 1 {
679682
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
680683
}
681-
v := reflect.ValueOf(args[0])
684+
v := deref.ValueOf(args[0])
682685
if v.Kind() != reflect.Map {
683686
return nil, fmt.Errorf("cannot get keys from %s", v.Kind())
684687
}
@@ -708,7 +711,7 @@ var Builtins = []*Function{
708711
if len(args) != 1 {
709712
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
710713
}
711-
v := reflect.ValueOf(args[0])
714+
v := deref.ValueOf(args[0])
712715
if v.Kind() != reflect.Map {
713716
return nil, fmt.Errorf("cannot get values from %s", v.Kind())
714717
}
@@ -738,7 +741,7 @@ var Builtins = []*Function{
738741
if len(args) != 1 {
739742
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
740743
}
741-
v := reflect.ValueOf(args[0])
744+
v := deref.ValueOf(args[0])
742745
if v.Kind() != reflect.Map {
743746
return nil, fmt.Errorf("cannot transform %s to pairs", v.Kind())
744747
}
@@ -766,7 +769,7 @@ var Builtins = []*Function{
766769
if len(args) != 1 {
767770
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
768771
}
769-
v := reflect.ValueOf(args[0])
772+
v := deref.ValueOf(args[0])
770773
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
771774
return nil, fmt.Errorf("cannot transform %s from pairs", v)
772775
}
@@ -798,14 +801,14 @@ var Builtins = []*Function{
798801
},
799802
{
800803
Name: "reverse",
801-
Func: func(args ...any) (any, error) {
804+
Safe: func(args ...any) (any, uint, error) {
802805
if len(args) != 1 {
803-
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
806+
return nil, 0, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
804807
}
805808

806-
v := reflect.ValueOf(args[0])
809+
v := deref.ValueOf(args[0])
807810
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
808-
return nil, fmt.Errorf("cannot reverse %s", v.Kind())
811+
return nil, 0, fmt.Errorf("cannot reverse %s", v.Kind())
809812
}
810813

811814
size := v.Len()
@@ -815,7 +818,7 @@ var Builtins = []*Function{
815818
arr[i] = v.Index(size - i - 1).Interface()
816819
}
817820

818-
return arr, nil
821+
return arr, uint(size), nil
819822

820823
},
821824
Validate: func(args []reflect.Type) (reflect.Type, error) {
@@ -838,7 +841,7 @@ var Builtins = []*Function{
838841
return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
839842
}
840843

841-
v := reflect.ValueOf(deref.Deref(args[0]))
844+
v := deref.ValueOf(deref.Deref(args[0]))
842845
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
843846
return nil, fmt.Errorf("cannot uniq %s", v.Kind())
844847
}
@@ -892,7 +895,7 @@ var Builtins = []*Function{
892895
var arr []any
893896

894897
for _, arg := range args {
895-
v := reflect.ValueOf(deref.Deref(arg))
898+
v := deref.ValueOf(deref.Deref(arg))
896899

897900
if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
898901
return nil, 0, fmt.Errorf("cannot concat %s", v.Kind())
@@ -914,7 +917,7 @@ var Builtins = []*Function{
914917
}
915918

916919
for _, arg := range args {
917-
switch kind(deref.Type(arg)) {
920+
switch kind(arg) {
918921
case reflect.Interface, reflect.Slice, reflect.Array:
919922
default:
920923
return anyType, fmt.Errorf("cannot concat %s", arg)
@@ -931,7 +934,7 @@ var Builtins = []*Function{
931934
if len(args) != 1 {
932935
return nil, 0, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args))
933936
}
934-
v := reflect.ValueOf(deref.Deref(args[0]))
937+
v := deref.ValueOf(deref.Deref(args[0]))
935938
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
936939
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
937940
}
@@ -945,7 +948,7 @@ var Builtins = []*Function{
945948
}
946949

947950
for _, arg := range args {
948-
switch kind(deref.Type(arg)) {
951+
switch kind(arg) {
949952
case reflect.Interface, reflect.Slice, reflect.Array:
950953
default:
951954
return anyType, fmt.Errorf("cannot flatten %s", arg)

builtin/builtin_test.go

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -639,15 +639,46 @@ func Test_int_unwraps_underlying_value(t *testing.T) {
639639
assert.Equal(t, true, out)
640640
}
641641

642-
func TestBuiltin_int_with_deref(t *testing.T) {
642+
func TestBuiltin_with_deref(t *testing.T) {
643643
x := 42
644+
arr := []any{1, 2, 3}
645+
m := map[string]any{"a": 1, "b": 2}
644646
env := map[string]any{
645-
"x": &x,
647+
"x": &x,
648+
"arr": &arr,
649+
"m": &m,
646650
}
647-
program, err := expr.Compile(`int(x)`, expr.Env(env))
648-
require.NoError(t, err)
649651

650-
out, err := expr.Run(program, env)
651-
require.NoError(t, err)
652-
assert.Equal(t, 42, out)
652+
tests := []struct {
653+
input string
654+
want any
655+
}{
656+
{`int(x)`, 42},
657+
{`float(x)`, 42.0},
658+
{`abs(x)`, 42},
659+
{`first(arr)`, 1},
660+
{`last(arr)`, 3},
661+
{`take(arr, 1)`, []any{1}},
662+
{`take(arr, x)`, []any{1, 2, 3}},
663+
{`'a' in keys(m)`, true},
664+
{`1 in values(m)`, true},
665+
{`len(arr)`, 3},
666+
{`type(arr)`, "array"},
667+
{`type(m)`, "map"},
668+
{`reverse(arr)`, []any{3, 2, 1}},
669+
{`uniq(arr)`, []any{1, 2, 3}},
670+
{`concat(arr, arr)`, []any{1, 2, 3, 1, 2, 3}},
671+
{`flatten([arr, [arr]])`, []any{1, 2, 3, 1, 2, 3}},
672+
}
673+
674+
for _, test := range tests {
675+
t.Run(test.input, func(t *testing.T) {
676+
program, err := expr.Compile(test.input, expr.Env(env))
677+
require.NoError(t, err)
678+
679+
out, err := expr.Run(program, env)
680+
require.NoError(t, err)
681+
assert.Equal(t, test.want, out)
682+
})
683+
}
653684
}

builtin/lib.go

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
func Len(x any) any {
14-
v := reflect.ValueOf(x)
14+
v := deref.ValueOf(x)
1515
switch v.Kind() {
1616
case reflect.Array, reflect.Slice, reflect.Map:
1717
return v.Len()
@@ -26,16 +26,7 @@ func Type(arg any) any {
2626
if arg == nil {
2727
return "nil"
2828
}
29-
v := reflect.ValueOf(arg)
30-
for {
31-
if v.Kind() == reflect.Ptr {
32-
v = v.Elem()
33-
} else if v.Kind() == reflect.Interface {
34-
v = v.Elem()
35-
} else {
36-
break
37-
}
38-
}
29+
v := deref.ValueOf(arg)
3930
if v.Type().Name() != "" && v.Type().PkgPath() != "" {
4031
return fmt.Sprintf("%s.%s", v.Type().PkgPath(), v.Type().Name())
4132
}
@@ -66,7 +57,7 @@ func Type(arg any) any {
6657
}
6758

6859
func Abs(x any) any {
69-
switch x := x.(type) {
60+
switch x := deref.Deref(x).(type) {
7061
case float32:
7162
if x < 0 {
7263
return -x
@@ -221,7 +212,7 @@ func Int(x any) any {
221212
}
222213

223214
func Float(x any) any {
224-
switch x := x.(type) {
215+
switch x := deref.Deref(x).(type) {
225216
case float32:
226217
return float64(x)
227218
case float64:
@@ -264,7 +255,7 @@ func String(arg any) any {
264255
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
265256
var val any
266257
for _, arg := range args {
267-
rv := reflect.ValueOf(deref.Deref(arg))
258+
rv := deref.ValueOf(arg)
268259
switch rv.Kind() {
269260
case reflect.Array, reflect.Slice:
270261
size := rv.Len()
@@ -307,7 +298,7 @@ func mean(args ...any) (int, float64, error) {
307298
var count int
308299

309300
for _, arg := range args {
310-
rv := reflect.ValueOf(deref.Deref(arg))
301+
rv := deref.ValueOf(arg)
311302
switch rv.Kind() {
312303
case reflect.Array, reflect.Slice:
313304
size := rv.Len()
@@ -339,7 +330,7 @@ func median(args ...any) ([]float64, error) {
339330
var values []float64
340331

341332
for _, arg := range args {
342-
rv := reflect.ValueOf(deref.Deref(arg))
333+
rv := deref.ValueOf(arg)
343334
switch rv.Kind() {
344335
case reflect.Array, reflect.Slice:
345336
size := rv.Len()

builtin/utils.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"fmt"
55
"reflect"
66
"time"
7+
8+
"github.com/expr-lang/expr/internal/deref"
79
)
810

911
var (
@@ -20,6 +22,7 @@ func kind(t reflect.Type) reflect.Kind {
2022
if t == nil {
2123
return reflect.Invalid
2224
}
25+
t = deref.Type(t)
2326
return t.Kind()
2427
}
2528

internal/deref/deref.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ func Value(v reflect.Value) reflect.Value {
4545
}
4646
return v
4747
}
48+
49+
func ValueOf(v any) reflect.Value {
50+
return Value(reflect.ValueOf(v))
51+
}

0 commit comments

Comments
 (0)