Skip to content

Commit e7ec0ea

Browse files
committed
Add SetToken method
1 parent fd98100 commit e7ec0ea

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

data.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,20 @@ func (s *SessionManager) Token(ctx context.Context) string {
603603
return sd.token
604604
}
605605

606+
// SetToken changes the token for the session to a known value. Please take care
607+
// when using this function to ensure that the token you are setting is an
608+
// unguessable value from a trusted source. Most applications will not need to
609+
// use this method.
610+
func (s *SessionManager) SetToken(ctx context.Context, token string) {
611+
sd := s.getSessionDataFromContext(ctx)
612+
613+
sd.mu.Lock()
614+
defer sd.mu.Unlock()
615+
616+
sd.token = token
617+
sd.status = Modified
618+
}
619+
606620
func (s *SessionManager) addSessionDataToContext(ctx context.Context, sd *sessionData) context.Context {
607621
return context.WithValue(ctx, s.ContextKey, sd)
608622
}

session_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,39 @@ func TestIterate(t *testing.T) {
348348
t.Fatal("didn't get expected error")
349349
}
350350
}
351+
352+
func TestSetToken(t *testing.T) {
353+
t.Parallel()
354+
355+
sessionManager := New()
356+
357+
mux := http.NewServeMux()
358+
mux.HandleFunc("/put", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
359+
sessionManager.SetToken(r.Context(), "my-custom-token")
360+
sessionManager.Put(r.Context(), "foo", "bar")
361+
}))
362+
mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
363+
v := sessionManager.Get(r.Context(), "foo")
364+
if v == nil {
365+
http.Error(w, "foo does not exist in session", 500)
366+
return
367+
}
368+
w.Write([]byte(v.(string)))
369+
}))
370+
371+
ts := newTestServer(t, sessionManager.LoadAndSave(mux))
372+
defer ts.Close()
373+
374+
header, _ := ts.execute(t, "/put")
375+
cookie := header.Get("Set-Cookie")
376+
token := extractTokenFromCookie(cookie)
377+
378+
if token != "my-custom-token" {
379+
t.Errorf("want %q; got %q", "my-custom-token", token)
380+
}
381+
382+
_, body := ts.execute(t, "/get")
383+
if body != "bar" {
384+
t.Errorf("want %q; got %q", "bar", body)
385+
}
386+
}

0 commit comments

Comments
 (0)