@@ -21,15 +21,19 @@ import (
21
21
"crypto/x509"
22
22
"encoding/base64"
23
23
"fmt"
24
- "io/ioutil "
24
+ "io"
25
25
26
26
"encoding/xml"
27
27
28
28
"github.com/beevik/etree"
29
+ rtvalidator "github.com/mattermost/xml-roundtrip-validator"
29
30
"github.com/russellhaering/gosaml2/types"
30
31
dsig "github.com/russellhaering/goxmldsig"
31
32
"github.com/russellhaering/goxmldsig/etreeutils"
32
- rtvalidator "github.com/mattermost/xml-roundtrip-validator"
33
+ )
34
+
35
+ const (
36
+ defaultMaxDecompressedResponseSize = 5 * 1024 * 1024
33
37
)
34
38
35
39
func (sp * SAMLServiceProvider ) validationContext () * dsig.ValidationContext {
@@ -174,7 +178,7 @@ func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
174
178
return fmt .Errorf ("unable to decrypt encrypted assertion: %v" , derr )
175
179
}
176
180
177
- doc , _ , err := parseResponse (raw )
181
+ doc , _ , err := parseResponse (raw , sp . MaximumDecompressedBodySize )
178
182
if err != nil {
179
183
return fmt .Errorf ("unable to create element from decrypted assertion bytes: %v" , err )
180
184
}
@@ -250,17 +254,17 @@ func (sp *SAMLServiceProvider) validateAssertionSignatures(el *etree.Element) er
250
254
}
251
255
}
252
256
253
- //ValidateEncodedResponse both decodes and validates, based on SP
254
- //configuration, an encoded, signed response. It will also appropriately
255
- //decrypt a response if the assertion was encrypted
257
+ // ValidateEncodedResponse both decodes and validates, based on SP
258
+ // configuration, an encoded, signed response. It will also appropriately
259
+ // decrypt a response if the assertion was encrypted
256
260
func (sp * SAMLServiceProvider ) ValidateEncodedResponse (encodedResponse string ) (* types.Response , error ) {
257
261
raw , err := base64 .StdEncoding .DecodeString (encodedResponse )
258
262
if err != nil {
259
263
return nil , err
260
264
}
261
265
262
266
// Parse the raw response
263
- doc , el , err := parseResponse (raw )
267
+ doc , el , err := parseResponse (raw , sp . MaximumDecompressedBodySize )
264
268
if err != nil {
265
269
return nil , err
266
270
}
@@ -330,7 +334,7 @@ func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBase
330
334
331
335
var response * types.UnverifiedBaseResponse
332
336
333
- err = maybeDeflate (raw , func (maybeXML []byte ) error {
337
+ err = maybeDeflate (raw , defaultMaxDecompressedResponseSize , func (maybeXML []byte ) error {
334
338
response = & types.UnverifiedBaseResponse {}
335
339
return xml .Unmarshal (maybeXML , response )
336
340
})
@@ -344,26 +348,37 @@ func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBase
344
348
// maybeDeflate invokes the passed decoder over the passed data. If an error is
345
349
// returned, it then attempts to deflate the passed data before re-invoking
346
350
// the decoder over the deflated data.
347
- func maybeDeflate (data []byte , decoder func ([]byte ) error ) error {
351
+ func maybeDeflate (data []byte , maxSize int64 , decoder func ([]byte ) error ) error {
348
352
err := decoder (data )
349
353
if err == nil {
350
354
return nil
351
355
}
352
356
353
- deflated , err := ioutil .ReadAll (flate .NewReader (bytes .NewReader (data )))
357
+ // Default to 5MB max size
358
+ if maxSize == 0 {
359
+ maxSize = defaultMaxDecompressedResponseSize
360
+ }
361
+
362
+ lr := io .LimitReader (flate .NewReader (bytes .NewReader (data )), maxSize + 1 )
363
+
364
+ deflated , err := io .ReadAll (lr )
354
365
if err != nil {
355
366
return err
356
367
}
357
368
369
+ if int64 (len (deflated )) > maxSize {
370
+ return fmt .Errorf ("deflated response exceeds maximum size of %d bytes" , maxSize )
371
+ }
372
+
358
373
return decoder (deflated )
359
374
}
360
375
361
376
// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested
362
- func parseResponse (xml []byte ) (* etree.Document , * etree.Element , error ) {
377
+ func parseResponse (xml []byte , maxSize int64 ) (* etree.Document , * etree.Element , error ) {
363
378
var doc * etree.Document
364
379
var rawXML []byte
365
380
366
- err := maybeDeflate (xml , func (xml []byte ) error {
381
+ err := maybeDeflate (xml , maxSize , func (xml []byte ) error {
367
382
doc = etree .NewDocument ()
368
383
rawXML = xml
369
384
return doc .ReadFromBytes (xml )
@@ -395,7 +410,7 @@ func DecodeUnverifiedLogoutResponse(encodedResponse string) (*types.LogoutRespon
395
410
396
411
var response * types.LogoutResponse
397
412
398
- err = maybeDeflate (raw , func (maybeXML []byte ) error {
413
+ err = maybeDeflate (raw , defaultMaxDecompressedResponseSize , func (maybeXML []byte ) error {
399
414
response = & types.LogoutResponse {}
400
415
return xml .Unmarshal (maybeXML , response )
401
416
})
@@ -413,7 +428,7 @@ func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse
413
428
}
414
429
415
430
// Parse the raw response
416
- doc , el , err := parseResponse (raw )
431
+ doc , el , err := parseResponse (raw , sp . MaximumDecompressedBodySize )
417
432
if err != nil {
418
433
return nil , err
419
434
}
0 commit comments