Skip to content

Commit c19fea2

Browse files
authored
Buffer body read up to MaxRequestSize (#24354) (#24365)
* Buffer body read up to MaxRequestSize (#24354) * adding back a context
1 parent 79f170d commit c19fea2

File tree

8 files changed

+167
-82
lines changed

8 files changed

+167
-82
lines changed

helper/forwarding/util.go

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ import (
44
"bytes"
55
"crypto/tls"
66
"crypto/x509"
7-
"errors"
87
"io"
9-
"io/ioutil"
108
"net/http"
119
"net/url"
1210
"os"
@@ -60,19 +58,7 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
6058

6159
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
6260
var reader io.Reader = req.Body
63-
ctx := req.Context()
64-
maxRequestSize := ctx.Value("max_request_size")
65-
if maxRequestSize != nil {
66-
max, ok := maxRequestSize.(int64)
67-
if !ok {
68-
return nil, errors.New("could not parse max_request_size from request context")
69-
}
70-
if max > 0 {
71-
reader = io.LimitReader(req.Body, max)
72-
}
73-
}
74-
75-
body, err := ioutil.ReadAll(reader)
61+
body, err := io.ReadAll(reader)
7662
if err != nil {
7763
return nil, err
7864
}

http/handler.go

Lines changed: 4 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,13 @@ func handler(props *vault.HandlerProperties) http.Handler {
226226
corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core)
227227
quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core)
228228
genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props)
229+
wrappedHandler := wrapMaxRequestSizeHandler(genericWrappedHandler, props)
229230

230231
// Wrap the handler with PrintablePathCheckHandler to check for non-printable
231232
// characters in the request path.
232-
printablePathCheckHandler := genericWrappedHandler
233+
printablePathCheckHandler := wrappedHandler
233234
if !props.DisablePrintableCheck {
234-
printablePathCheckHandler = cleanhttp.PrintablePathCheckHandler(genericWrappedHandler, nil)
235+
printablePathCheckHandler = cleanhttp.PrintablePathCheckHandler(wrappedHandler, nil)
235236
}
236237

237238
return printablePathCheckHandler
@@ -310,18 +311,12 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
310311
// are performed.
311312
func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler {
312313
var maxRequestDuration time.Duration
313-
var maxRequestSize int64
314314
if props.ListenerConfig != nil {
315315
maxRequestDuration = props.ListenerConfig.MaxRequestDuration
316-
maxRequestSize = props.ListenerConfig.MaxRequestSize
317316
}
318317
if maxRequestDuration == 0 {
319318
maxRequestDuration = vault.DefaultMaxRequestDuration
320319
}
321-
if maxRequestSize == 0 {
322-
maxRequestSize = DefaultMaxRequestSize
323-
}
324-
325320
// Swallow this error since we don't want to pollute the logs and we also don't want to
326321
// return an HTTP error here. This information is best effort.
327322
hostname, _ := os.Hostname()
@@ -355,11 +350,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
355350
} else {
356351
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
357352
}
358-
// if maxRequestSize < 0, no need to set context value
359-
// Add a size limiter if desired
360-
if maxRequestSize > 0 {
361-
ctx = context.WithValue(ctx, "max_request_size", maxRequestSize)
362-
}
363353
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
364354
r = r.WithContext(ctx)
365355
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
@@ -703,25 +693,7 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
703693
// Limit the maximum number of bytes to MaxRequestSize to protect
704694
// against an indefinite amount of data being read.
705695
reader := r.Body
706-
ctx := r.Context()
707-
maxRequestSize := ctx.Value("max_request_size")
708-
if maxRequestSize != nil {
709-
max, ok := maxRequestSize.(int64)
710-
if !ok {
711-
return nil, errors.New("could not parse max_request_size from request context")
712-
}
713-
if max > 0 {
714-
// MaxBytesReader won't do all the internal stuff it must unless it's
715-
// given a ResponseWriter that implements the internal http interface
716-
// requestTooLarger. So we let it have access to the underlying
717-
// ResponseWriter.
718-
inw := w
719-
if myw, ok := inw.(logical.WrappingResponseWriter); ok {
720-
inw = myw.Wrapped()
721-
}
722-
reader = http.MaxBytesReader(inw, r.Body, max)
723-
}
724-
}
696+
725697
var origBody io.ReadWriter
726698
if perfStandby {
727699
// Since we're checking PerfStandby here we key on origBody being nil
@@ -743,16 +715,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
743715
//
744716
// A nil map will be returned if the format is empty or invalid.
745717
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
746-
maxRequestSize := r.Context().Value("max_request_size")
747-
if maxRequestSize != nil {
748-
max, ok := maxRequestSize.(int64)
749-
if !ok {
750-
return nil, errors.New("could not parse max_request_size from request context")
751-
}
752-
if max > 0 {
753-
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
754-
}
755-
}
756718
if err := r.ParseForm(); err != nil {
757719
return nil, err
758720
}

http/handler_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package http
22

33
import (
4+
"bytes"
45
"context"
56
"crypto/tls"
67
"encoding/json"
@@ -11,16 +12,19 @@ import (
1112
"net/textproto"
1213
"net/url"
1314
"reflect"
15+
"runtime"
1416
"strings"
1517
"testing"
1618

1719
"github.com/go-test/deep"
1820
"github.com/hashicorp/go-cleanhttp"
1921
"github.com/hashicorp/vault/helper/namespace"
2022
"github.com/hashicorp/vault/helper/versions"
23+
"github.com/hashicorp/vault/internalshared/configutil"
2124
"github.com/hashicorp/vault/sdk/helper/consts"
2225
"github.com/hashicorp/vault/sdk/logical"
2326
"github.com/hashicorp/vault/vault"
27+
"github.com/stretchr/testify/require"
2428
)
2529

2630
func TestHandler_parseMFAHandler(t *testing.T) {
@@ -884,3 +888,59 @@ func TestHandler_Parse_Form(t *testing.T) {
884888
t.Fatal(diff)
885889
}
886890
}
891+
892+
// TestHandler_MaxRequestSize verifies that a request larger than the
893+
// MaxRequestSize fails
894+
func TestHandler_MaxRequestSize(t *testing.T) {
895+
t.Parallel()
896+
cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{
897+
DefaultHandlerProperties: vault.HandlerProperties{
898+
ListenerConfig: &configutil.Listener{
899+
MaxRequestSize: 1024,
900+
},
901+
},
902+
HandlerFunc: Handler,
903+
NumCores: 1,
904+
})
905+
cluster.Start()
906+
defer cluster.Cleanup()
907+
908+
client := cluster.Cores[0].Client
909+
_, err := client.KVv2("secret").Put(context.Background(), "foo", map[string]interface{}{
910+
"bar": strings.Repeat("a", 1025),
911+
})
912+
913+
require.ErrorContains(t, err, "error parsing JSON")
914+
}
915+
916+
// TestHandler_MaxRequestSize_Memory sets the max request size to 1024 bytes,
917+
// and creates a 1MB request. The test verifies that less than 1MB of memory is
918+
// allocated when the request is sent. This test shouldn't be run in parallel,
919+
// because it modifies GOMAXPROCS
920+
func TestHandler_MaxRequestSize_Memory(t *testing.T) {
921+
ln, addr := TestListener(t)
922+
core, _, token := vault.TestCoreUnsealed(t)
923+
TestServerWithListenerAndProperties(t, ln, addr, core, &vault.HandlerProperties{
924+
Core: core,
925+
ListenerConfig: &configutil.Listener{
926+
Address: addr,
927+
MaxRequestSize: 1024,
928+
},
929+
})
930+
defer ln.Close()
931+
932+
data := bytes.Repeat([]byte{0x1}, 1024*1024)
933+
934+
req, err := http.NewRequest("POST", addr+"/v1/sys/unseal", bytes.NewReader(data))
935+
require.NoError(t, err)
936+
req.Header.Set(consts.AuthHeaderName, token)
937+
938+
client := cleanhttp.DefaultClient()
939+
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
940+
var start, end runtime.MemStats
941+
runtime.GC()
942+
runtime.ReadMemStats(&start)
943+
client.Do(req)
944+
runtime.ReadMemStats(&end)
945+
require.Less(t, end.TotalAlloc-start.TotalAlloc, uint64(1024*1024))
946+
}

http/util.go

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ package http
33
import (
44
"bytes"
55
"context"
6-
"errors"
76
"fmt"
8-
"io/ioutil"
7+
"io"
98
"net"
109
"net/http"
1110
"strings"
1211

12+
"github.com/hashicorp/go-multierror"
1313
"github.com/hashicorp/vault/sdk/logical"
1414

1515
"github.com/hashicorp/vault/helper/namespace"
@@ -35,6 +35,27 @@ var (
3535
adjustResponse = func(core *vault.Core, w http.ResponseWriter, req *logical.Request) {}
3636
)
3737

38+
func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler {
39+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
40+
var maxRequestSize int64
41+
if props.ListenerConfig != nil {
42+
maxRequestSize = props.ListenerConfig.MaxRequestSize
43+
}
44+
if maxRequestSize == 0 {
45+
maxRequestSize = DefaultMaxRequestSize
46+
}
47+
ctx := r.Context()
48+
originalBody := r.Body
49+
if maxRequestSize > 0 {
50+
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
51+
}
52+
ctx = logical.CreateContextOriginalBody(ctx, originalBody)
53+
r = r.WithContext(ctx)
54+
55+
handler.ServeHTTP(w, r)
56+
})
57+
}
58+
3859
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
3960
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4061
ns, err := namespace.FromContext(r.Context())
@@ -53,14 +74,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
5374
}
5475
mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path)
5576

56-
// Clone body, so we do not close the request body reader
57-
bodyBytes, err := ioutil.ReadAll(r.Body)
58-
if err != nil {
59-
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
60-
return
61-
}
62-
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
63-
6477
quotaReq := &quotas.Request{
6578
Type: quotas.TypeRateLimit,
6679
Path: path,
@@ -80,7 +93,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
8093
// If any role-based quotas are enabled for this namespace/mount, just
8194
// do the role resolution once here.
8295
if requiresResolveRole {
83-
role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes)
96+
buf := bytes.Buffer{}
97+
teeReader := io.TeeReader(r.Body, &buf)
98+
role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader)
99+
100+
// Reset the body if it was read
101+
if buf.Len() > 0 {
102+
r.Body = io.NopCloser(&buf)
103+
originalBody, ok := logical.ContextOriginalBodyValue(r.Context())
104+
if ok {
105+
r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody)))
106+
}
107+
}
84108
// add an entry to the context to prevent recalculating request role unnecessarily
85109
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))
86110
quotaReq.Role = role
@@ -139,3 +163,25 @@ func parseRemoteIPAddress(r *http.Request) string {
139163

140164
return ip
141165
}
166+
167+
type multiReaderCloser struct {
168+
readers []io.Reader
169+
io.Reader
170+
}
171+
172+
func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser {
173+
return &multiReaderCloser{
174+
readers: readers,
175+
Reader: io.MultiReader(readers...),
176+
}
177+
}
178+
179+
func (m *multiReaderCloser) Close() error {
180+
var err error
181+
for _, r := range m.readers {
182+
if c, ok := r.(io.Closer); ok {
183+
err = multierror.Append(err, c.Close())
184+
}
185+
}
186+
return err
187+
}

sdk/logical/request.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package logical
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"net/http"
78
"strings"
89
"time"
@@ -398,3 +399,14 @@ type CtxKeyRequestRole struct{}
398399
func (c CtxKeyRequestRole) String() string {
399400
return "request-role"
400401
}
402+
403+
type ctxKeyOriginalBody struct{}
404+
405+
func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) {
406+
value, ok := ctx.Value(ctxKeyOriginalBody{}).(io.ReadCloser)
407+
return value, ok
408+
}
409+
410+
func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context {
411+
return context.WithValue(parent, ctxKeyOriginalBody{}, body)
412+
}

vault/core.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3888,22 +3888,24 @@ func (c *Core) LoadNodeID() (string, error) {
38883888
return hostname, nil
38893889
}
38903890

3891-
// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given
3892-
// login request, accepting a byte payload
3893-
func (c *Core) DetermineRoleFromLoginRequestFromBytes(ctx context.Context, mountPoint string, payload []byte) string {
3894-
data := make(map[string]interface{})
3895-
err := jsonutil.DecodeJSON(payload, &data)
3896-
if err != nil {
3897-
// Cannot discern a role from a request we cannot parse
3891+
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
3892+
// login request
3893+
func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string {
3894+
c.authLock.RLock()
3895+
defer c.authLock.RUnlock()
3896+
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
3897+
if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential {
3898+
// Role based quotas do not apply to this request
38983899
return ""
38993900
}
3900-
3901-
return c.DetermineRoleFromLoginRequest(ctx, mountPoint, data)
3901+
return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data)
39023902
}
39033903

3904-
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
3905-
// login request
3906-
func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string {
3904+
// DetermineRoleFromLoginRequestFromReader will determine the role that should
3905+
// be applied to a quota for a given login request. The reader will only be
3906+
// consumed if the matching backend for the mount point exists and is a secret
3907+
// backend
3908+
func (c *Core) DetermineRoleFromLoginRequestFromReader(ctx context.Context, mountPoint string, reader io.Reader) string {
39073909
c.authLock.RLock()
39083910
defer c.authLock.RUnlock()
39093911
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
@@ -3912,6 +3914,17 @@ func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint str
39123914
return ""
39133915
}
39143916

3917+
data := make(map[string]interface{})
3918+
err := jsonutil.DecodeJSONFromReader(reader, &data)
3919+
if err != nil {
3920+
return ""
3921+
}
3922+
return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data)
3923+
}
3924+
3925+
// doResolveRoleLocked does a login and resolve role request on the matching
3926+
// backend. Callers should have a read lock on c.authLock
3927+
func (c *Core) doResolveRoleLocked(ctx context.Context, mountPoint string, matchingBackend logical.Backend, data map[string]interface{}) string {
39153928
resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{
39163929
MountPoint: mountPoint,
39173930
Path: "login",

0 commit comments

Comments
 (0)