diff --git a/patcher/with_context.go b/patcher/with_context.go index 55b604261..f9861a2c2 100644 --- a/patcher/with_context.go +++ b/patcher/with_context.go @@ -22,11 +22,18 @@ func (w WithContext) Visit(node *ast.Node) { if fn.Kind() != reflect.Func { return } - if fn.NumIn() == 0 { - return - } - if fn.In(0).String() != "context.Context" { + switch fn.NumIn() { + case 0: return + case 1: + if fn.In(0).String() != "context.Context" { + return + } + default: + if fn.In(0).String() != "context.Context" && + fn.In(1).String() != "context.Context" { + return + } } ast.Patch(node, &ast.CallNode{ Callee: call.Callee, diff --git a/patcher/with_context_test.go b/patcher/with_context_test.go index afad4e6f0..5ce64191f 100644 --- a/patcher/with_context_test.go +++ b/patcher/with_context_test.go @@ -62,6 +62,30 @@ func TestWithContext_with_env_Function(t *testing.T) { require.Equal(t, 42, output) } +type testEnvContext struct { + Context context.Context `expr:"ctx"` +} + +func (testEnvContext) Fn(ctx context.Context, a int) int { + return ctx.Value("value").(int) + a +} + +func TestWithContext_env_struct(t *testing.T) { + withContext := patcher.WithContext{Name: "ctx"} + + program, err := expr.Compile(`Fn(40)`, expr.Env(testEnvContext{}), expr.Patch(withContext)) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), "value", 2) + env := testEnvContext{ + Context: ctx, + } + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, output) +} + type TestFoo struct { contextValue int }