Skip to content

Commit e4ff415

Browse files
authored
Support DataChannel messages larger then MaxUint16
SCTP now internally can handle larger messages Resolves #2712
1 parent 98a0025 commit e4ff415

12 files changed

+264
-89
lines changed

constants.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44
package webrtc
55

6-
import "github.com/pion/dtls/v3"
6+
import (
7+
"math"
8+
9+
"github.com/pion/dtls/v3"
10+
)
711

812
const (
913
// default as the standard ethernet MTU
@@ -19,6 +23,15 @@ const (
1923
// If the total amount of incoming SSRCes exceeds this new requests will be ignored.
2024
simulcastMaxProbeRoutines = 25
2125

26+
// Default Max SCTP Message Size is the largest single DataChannel
27+
// message we can send or accept. This default was chosen to match FireFox.
28+
defaultMaxSCTPMessageSize = 1073741823
29+
30+
// If a DataChannel Max Message Size isn't declared by the Remote(max-message-size)
31+
// this is the value we default to. This value was chosen because it was the behavior
32+
// of Pion before max-message-size was implemented.
33+
sctpMaxMessageSizeUnsetValue = math.MaxUint16
34+
2235
mediaSectionApplication = "application"
2336

2437
sdpAttributeRid = "rid"

datachannel.go

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13-
"math"
1413
"sync"
1514
"sync/atomic"
1615
"time"
@@ -20,7 +19,6 @@ import (
2019
"github.com/pion/webrtc/v4/pkg/rtcerr"
2120
)
2221

23-
const dataChannelBufferSize = math.MaxUint16 // message size limit for Chromium
2422
var errSCTPNotEstablished = errors.New("SCTP not established")
2523

2624
// DataChannel represents a WebRTC DataChannel
@@ -404,10 +402,24 @@ func (d *DataChannel) readLoop() {
404402
d.mu.Unlock()
405403
defer close(readLoopActive)
406404
}()
407-
buffer := make([]byte, dataChannelBufferSize)
405+
406+
buffer := make([]byte, sctpMaxMessageSizeUnsetValue)
408407
for {
409408
n, isString, err := d.dataChannel.ReadDataChannel(buffer)
410409
if err != nil {
410+
if errors.Is(err, io.ErrShortBuffer) {
411+
if int64(n) < int64(d.api.settingEngine.getSCTPMaxMessageSize()) {
412+
buffer = append(buffer, make([]byte, len(buffer))...) // nolint
413+
414+
continue
415+
}
416+
417+
d.log.Errorf(
418+
"Incoming DataChannel message larger then Max Message size %v",
419+
d.api.settingEngine.getSCTPMaxMessageSize(),
420+
)
421+
}
422+
411423
d.setReadyState(DataChannelStateClosed)
412424
if !errors.Is(err, io.EOF) {
413425
d.onError(err)
@@ -417,11 +429,10 @@ func (d *DataChannel) readLoop() {
417429
return
418430
}
419431

420-
msg := DataChannelMessage{Data: make([]byte, n), IsString: isString}
421-
copy(msg.Data, buffer[:n])
422-
423-
// NB: Why was DataChannelMessage not passed as a pointer value?
424-
d.onMessage(msg) // nolint:staticcheck
432+
d.onMessage(DataChannelMessage{
433+
Data: append([]byte{}, buffer[:n]...),
434+
IsString: isString,
435+
})
425436
}
426437
}
427438

datachannel_go_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package webrtc
88

99
import (
1010
"bytes"
11+
"context"
1112
"crypto/rand"
1213
"encoding/binary"
1314
"io"
@@ -851,3 +852,38 @@ func TestDataChannel_DetachErrors(t *testing.T) {
851852
assert.NoError(t, answer.Close())
852853
})
853854
}
855+
856+
func TestDataChannelMessageSize(t *testing.T) {
857+
offerPC, answerPC, err := newPair()
858+
assert.NoError(t, err)
859+
860+
dc, err := offerPC.CreateDataChannel("", nil)
861+
assert.NoError(t, err)
862+
863+
answerDataChannelMessages := make(chan []byte)
864+
answerPC.OnDataChannel(func(d *DataChannel) {
865+
d.OnMessage(func(m DataChannelMessage) {
866+
answerDataChannelMessages <- m.Data
867+
})
868+
})
869+
870+
assert.NoError(t, signalPair(offerPC, answerPC))
871+
872+
messagesSent, messagesSentCancel := context.WithCancel(context.Background())
873+
dc.OnOpen(func() {
874+
for i := 0; i <= 10; i++ {
875+
outboundMessage := make([]byte, sctpMaxMessageSizeUnsetValue*i)
876+
_, err := rand.Read(outboundMessage)
877+
assert.NoError(t, err)
878+
879+
assert.NoError(t, dc.Send(outboundMessage))
880+
inboundMessage := <-answerDataChannelMessages
881+
882+
assert.Equal(t, outboundMessage, inboundMessage)
883+
}
884+
messagesSentCancel()
885+
})
886+
887+
<-messagesSent.Done()
888+
closePairNow(t, offerPC, answerPC)
889+
}

datachannel_test.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,7 @@ func TestDataChannelParameters(t *testing.T) { //nolint:cyclop
510510
})
511511

512512
go func() {
513-
for {
514-
if seenAnswerMessage.get() && seenOfferMessage.get() {
515-
break
516-
}
517-
513+
for seenAnswerMessage.get() && seenOfferMessage.get() {
518514
if offerDatachannel.ReadyState() == DataChannelStateOpen {
519515
assert.NoError(t, offerDatachannel.SendText(expectedMessage))
520516
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ require (
1111
github.com/pion/randutil v0.1.0
1212
github.com/pion/rtcp v1.2.15
1313
github.com/pion/rtp v1.8.12
14-
github.com/pion/sctp v1.8.36
14+
github.com/pion/sctp v1.8.37
1515
github.com/pion/sdp/v3 v3.0.10
1616
github.com/pion/srtp/v3 v3.0.4
1717
github.com/pion/stun/v3 v3.0.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
5353
github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
5454
github.com/pion/rtp v1.8.12 h1:nsKs8Wi0jQyBFHU3qmn/OvtZrhktVfJY0vRxwACsL5U=
5555
github.com/pion/rtp v1.8.12/go.mod h1:8uMBJj32Pa1wwx8Fuv/AsFhn8jsgw+3rUC2PfoBZ8p4=
56-
github.com/pion/sctp v1.8.36 h1:owNudmnz1xmhfYje5L/FCav3V9wpPRePHle3Zi+P+M0=
57-
github.com/pion/sctp v1.8.36/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
56+
github.com/pion/sctp v1.8.37 h1:ZDmGPtRPX9mKCiVXtMbTWybFw3z/hVKAZgU81wcOrqs=
57+
github.com/pion/sctp v1.8.37/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
5858
github.com/pion/sdp/v3 v3.0.10 h1:6MChLE/1xYB+CjumMw+gZ9ufp2DPApuVSnDT8t5MIgA=
5959
github.com/pion/sdp/v3 v3.0.10/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
6060
github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M=

peerconnection.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,10 +1563,10 @@ func (pc *PeerConnection) startRTPSenders(currentTransceivers []*RTPTransceiver)
15631563
}
15641564

15651565
// Start SCTP subsystem.
1566-
func (pc *PeerConnection) startSCTP() {
1566+
func (pc *PeerConnection) startSCTP(maxMessageSize uint32) {
15671567
// Start sctp
15681568
if err := pc.sctpTransport.Start(SCTPCapabilities{
1569-
MaxMessageSize: 0,
1569+
MaxMessageSize: maxMessageSize,
15701570
}); err != nil {
15711571
pc.log.Warnf("Failed to start SCTP: %s", err)
15721572
if err = pc.sctpTransport.Stop(); err != nil {
@@ -2625,8 +2625,8 @@ func (pc *PeerConnection) startRTP(
26252625
}
26262626

26272627
pc.startRTPReceivers(remoteDesc, currentTransceivers)
2628-
if haveApplicationMediaSection(remoteDesc.parsed) {
2629-
pc.startSCTP()
2628+
if d := haveDataChannel(remoteDesc); d != nil {
2629+
pc.startSCTP(getMaxMessageSize(d))
26302630
}
26312631
}
26322632

@@ -2718,6 +2718,7 @@ func (pc *PeerConnection) generateUnmatchedSDP(
27182718
mediaSections,
27192719
pc.ICEGatheringState(),
27202720
nil,
2721+
pc.api.settingEngine.getSCTPMaxMessageSize(),
27212722
)
27222723
}
27232724

@@ -2884,6 +2885,7 @@ func (pc *PeerConnection) generateMatchedSDP(
28842885
mediaSections,
28852886
pc.ICEGatheringState(),
28862887
bundleGroup,
2888+
pc.api.settingEngine.getSCTPMaxMessageSize(),
28872889
)
28882890
}
28892891

sctptransport.go

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ package webrtc
99
import (
1010
"errors"
1111
"io"
12-
"math"
1312
"sync"
1413
"time"
1514

@@ -34,10 +33,6 @@ type SCTPTransport struct {
3433
// so we need a dedicated field
3534
isStarted bool
3635

37-
// MaxMessageSize represents the maximum size of data that can be passed to
38-
// DataChannel's send() method.
39-
maxMessageSize float64
40-
4136
// MaxChannels represents the maximum amount of DataChannel's that can
4237
// be used simultaneously.
4338
maxChannels *uint16
@@ -74,7 +69,6 @@ func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
7469
dataChannelIDsUsed: make(map[uint16]struct{}),
7570
}
7671

77-
res.updateMessageSize()
7872
res.updateMaxChannels()
7973

8074
return res
@@ -90,20 +84,30 @@ func (r *SCTPTransport) Transport() *DTLSTransport {
9084

9185
// GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
9286
func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
87+
var maxMessageSize uint32
88+
if a := r.association(); a != nil {
89+
maxMessageSize = a.MaxMessageSize()
90+
}
91+
9392
return SCTPCapabilities{
94-
MaxMessageSize: 0,
93+
MaxMessageSize: maxMessageSize,
9594
}
9695
}
9796

9897
// Start the SCTPTransport. Since both local and remote parties must mutually
9998
// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
10099
// a connection over SCTP.
101-
func (r *SCTPTransport) Start(_ SCTPCapabilities) error {
100+
func (r *SCTPTransport) Start(capabilities SCTPCapabilities) error {
102101
if r.isStarted {
103102
return nil
104103
}
105104
r.isStarted = true
106105

106+
maxMessageSize := capabilities.MaxMessageSize
107+
if maxMessageSize == 0 {
108+
maxMessageSize = sctpMaxMessageSizeUnsetValue
109+
}
110+
107111
dtlsTransport := r.Transport()
108112
if dtlsTransport == nil || dtlsTransport.conn == nil {
109113
return errSCTPTransportDTLS
@@ -115,6 +119,7 @@ func (r *SCTPTransport) Start(_ SCTPCapabilities) error {
115119
LoggerFactory: r.api.settingEngine.LoggerFactory,
116120
RTOMax: float64(r.api.settingEngine.sctp.rtoMax) / float64(time.Millisecond),
117121
BlockWrite: r.api.settingEngine.detach.DataChannels && r.api.settingEngine.dataChannelBlockWrite,
122+
MaxMessageSize: maxMessageSize,
118123
})
119124
if err != nil {
120125
return err
@@ -344,36 +349,6 @@ func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
344349
return
345350
}
346351

347-
func (r *SCTPTransport) updateMessageSize() {
348-
r.lock.Lock()
349-
defer r.lock.Unlock()
350-
351-
var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
352-
var canSendSize float64 = 65536 // pion/webrtc#758
353-
354-
r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
355-
}
356-
357-
func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
358-
switch {
359-
case remoteMaxMessageSize == 0 &&
360-
canSendSize == 0:
361-
return math.Inf(1)
362-
363-
case remoteMaxMessageSize == 0:
364-
return canSendSize
365-
366-
case canSendSize == 0:
367-
return remoteMaxMessageSize
368-
369-
case canSendSize > remoteMaxMessageSize:
370-
return remoteMaxMessageSize
371-
372-
default:
373-
return canSendSize
374-
}
375-
}
376-
377352
func (r *SCTPTransport) updateMaxChannels() {
378353
val := sctpMaxChannels
379354
r.maxChannels = &val

0 commit comments

Comments
 (0)