Skip to content

Commit c2c0b01

Browse files
authored
Add public accessors for request pattern and method (#175)
These are very useful values to be able to access easily while processing requests. Let's make them public and reachable via the context.
1 parent ab6ccb7 commit c2c0b01

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

router.go

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ import (
1414
// https://play.golang.org/p/MxhRiL37R-9
1515
type routerContextKeyType struct{}
1616
type routerRequestPatternContextKeyType struct{}
17+
type routerRequestMethodContextKeyType struct{}
1718

1819
var (
1920
routerContextKey = routerContextKeyType{}
2021
routerRequestPatternContextKey = routerRequestPatternContextKeyType{}
22+
routerRequestMethodContextKey = routerRequestMethodContextKeyType{}
2123
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
2224
)
2325

@@ -53,6 +55,22 @@ func routerPathPatternForRequest(r Request) string {
5355
return ""
5456
}
5557

58+
// RequestPatternFromContext returns the pattern that was matched for the request, if available.
59+
func RequestPatternFromContext(ctx context.Context) (string, bool) {
60+
if v := ctx.Value(routerRequestPatternContextKey); v != nil {
61+
return v.(string), true
62+
}
63+
return "", false
64+
}
65+
66+
// RequestMethodFromContext returns the method of the request, if available.
67+
func RequestMethodFromContext(ctx context.Context) (string, bool) {
68+
if v := ctx.Value(routerRequestMethodContextKey); v != nil {
69+
return v.(string), true
70+
}
71+
return "", false
72+
}
73+
5674
func (r *Router) compile(pattern string) *regexp.Regexp {
5775
re, pos := ``, 0
5876
for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) {
@@ -134,6 +152,7 @@ func (r Router) Serve() Service {
134152
}
135153
req.Context = context.WithValue(req.Context, routerContextKey, &r)
136154
req.Context = context.WithValue(req.Context, routerRequestPatternContextKey, pathPattern)
155+
req.Context = context.WithValue(req.Context, routerRequestMethodContextKey, req.Method)
137156
rsp := svc(req)
138157
if rsp.Request == nil {
139158
rsp.Request = &req
@@ -157,37 +176,46 @@ func (r Router) Params(req Request) map[string]string {
157176
// Sugar
158177

159178
// GET is shorthand for:
160-
// r.Register("GET", pattern, svc)
179+
//
180+
// r.Register("GET", pattern, svc)
161181
func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) }
162182

163183
// CONNECT is shorthand for:
164-
// r.Register("CONNECT", pattern, svc)
184+
//
185+
// r.Register("CONNECT", pattern, svc)
165186
func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) }
166187

167188
// DELETE is shorthand for:
168-
// r.Register("DELETE", pattern, svc)
189+
//
190+
// r.Register("DELETE", pattern, svc)
169191
func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) }
170192

171193
// HEAD is shorthand for:
172-
// r.Register("HEAD", pattern, svc)
194+
//
195+
// r.Register("HEAD", pattern, svc)
173196
func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) }
174197

175198
// OPTIONS is shorthand for:
176-
// r.Register("OPTIONS", pattern, svc)
199+
//
200+
// r.Register("OPTIONS", pattern, svc)
177201
func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) }
178202

179203
// PATCH is shorthand for:
180-
// r.Register("PATCH", pattern, svc)
204+
//
205+
// r.Register("PATCH", pattern, svc)
181206
func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) }
182207

183208
// POST is shorthand for:
184-
// r.Register("POST", pattern, svc)
209+
//
210+
// r.Register("POST", pattern, svc)
185211
func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) }
186212

187213
// PUT is shorthand for:
188-
// r.Register("PUT", pattern, svc)
214+
//
215+
// r.Register("PUT", pattern, svc)
189216
func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) }
190217

191218
// TRACE is shorthand for:
192-
// r.Register("TRACE", pattern, svc)
219+
//
220+
// r.Register("TRACE", pattern, svc)
193221
func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) }

router_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,25 @@ func TestRouterSetsRequest(t *testing.T) {
126126
req.Context = rsp.Request.Context
127127
assert.Equal(t, req, *rsp.Request)
128128
}
129+
130+
func TestRouterSetsContextValues(t *testing.T) {
131+
t.Parallel()
132+
133+
router := Router{}
134+
router.GET("/", func(req Request) Response {
135+
return Response{}
136+
})
137+
138+
ctx := context.Background()
139+
req := NewRequest(ctx, "GET", "/", map[string]string{"r": "foo"})
140+
rsp := router.Serve()(req)
141+
require.NotNil(t, rsp.Request)
142+
143+
ctxPattern, ok := RequestPatternFromContext(rsp.Request.Context)
144+
assert.True(t, ok)
145+
assert.Equal(t, "/", ctxPattern)
146+
147+
ctxMethod, ok := RequestMethodFromContext(rsp.Request.Context)
148+
assert.True(t, ok)
149+
assert.Equal(t, "GET", ctxMethod)
150+
}

0 commit comments

Comments
 (0)