Skip to content

Commit 835ac3b

Browse files
sukunrtSean-Der
authored andcommitted
Drop reference to detached datachannels
This allows users of detached datachannels to garbage collect resources associated with the datachannel and the sctp stream. There is no functional change here.
1 parent a8c02b0 commit 835ac3b

File tree

5 files changed

+102
-18
lines changed

5 files changed

+102
-18
lines changed

datachannel.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,6 @@ func (d *DataChannel) ensureOpen() error {
420420
// resulting DataChannel object.
421421
func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
422422
d.mu.Lock()
423-
defer d.mu.Unlock()
424423

425424
if !d.api.settingEngine.detach.DataChannels {
426425
return nil, errDetachNotEnabled
@@ -432,7 +431,28 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
432431

433432
d.detachCalled = true
434433

435-
return d.dataChannel, nil
434+
dataChannel := d.dataChannel
435+
d.mu.Unlock()
436+
437+
// Remove the reference from SCTPTransport so that the datachannel
438+
// can be garbage collected on close
439+
d.sctpTransport.lock.Lock()
440+
n := len(d.sctpTransport.dataChannels)
441+
j := 0
442+
for i := 0; i < n; i++ {
443+
if d == d.sctpTransport.dataChannels[i] {
444+
continue
445+
}
446+
d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i]
447+
j++
448+
}
449+
for i := j; i < n; i++ {
450+
d.sctpTransport.dataChannels[i] = nil
451+
}
452+
d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j]
453+
d.sctpTransport.lock.Unlock()
454+
455+
return dataChannel, nil
436456
}
437457

438458
// Close Closes the DataChannel. It may be called regardless of whether

datachannel_go_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,57 @@ func TestDataChannel_Dial(t *testing.T) {
692692
closePair(t, offerPC, answerPC, done)
693693
})
694694
}
695+
696+
func TestDetachRemovesDatachannelReference(t *testing.T) {
697+
// Use Detach data channels mode
698+
s := SettingEngine{}
699+
s.DetachDataChannels()
700+
api := NewAPI(WithSettingEngine(s))
701+
702+
// Set up two peer connections.
703+
config := Configuration{}
704+
pca, err := api.NewPeerConnection(config)
705+
if err != nil {
706+
t.Fatal(err)
707+
}
708+
pcb, err := api.NewPeerConnection(config)
709+
if err != nil {
710+
t.Fatal(err)
711+
}
712+
713+
defer closePairNow(t, pca, pcb)
714+
715+
dcChan := make(chan *DataChannel, 1)
716+
pcb.OnDataChannel(func(d *DataChannel) {
717+
d.OnOpen(func() {
718+
if _, detachErr := d.Detach(); detachErr != nil {
719+
t.Error(detachErr)
720+
}
721+
722+
dcChan <- d
723+
})
724+
})
725+
726+
if err = signalPair(pca, pcb); err != nil {
727+
t.Fatal(err)
728+
}
729+
730+
attached, err := pca.CreateDataChannel("", nil)
731+
if err != nil {
732+
t.Fatal(err)
733+
}
734+
open := make(chan struct{}, 1)
735+
attached.OnOpen(func() {
736+
open <- struct{}{}
737+
})
738+
<-open
739+
740+
d := <-dcChan
741+
d.sctpTransport.lock.RLock()
742+
defer d.sctpTransport.lock.RUnlock()
743+
for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] {
744+
if dc == d {
745+
t.Errorf("expected sctpTransport to drop reference to datachannel")
746+
}
747+
}
748+
}

peerconnection.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,6 +2018,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn
20182018

20192019
pc.sctpTransport.lock.Lock()
20202020
pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, d)
2021+
if d.ID() != nil {
2022+
pc.sctpTransport.dataChannelIDsUsed[*d.ID()] = struct{}{}
2023+
}
20212024
pc.sctpTransport.dataChannelsRequested++
20222025
pc.sctpTransport.lock.Unlock()
20232026

sctptransport.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ type SCTPTransport struct {
5252

5353
// DataChannels
5454
dataChannels []*DataChannel
55+
dataChannelIDsUsed map[uint16]struct{}
5556
dataChannelsOpened uint32
5657
dataChannelsRequested uint32
5758
dataChannelsAccepted uint32
@@ -65,10 +66,11 @@ type SCTPTransport struct {
6566
// meant to be used together with the basic WebRTC API.
6667
func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
6768
res := &SCTPTransport{
68-
dtlsTransport: dtls,
69-
state: SCTPTransportStateConnecting,
70-
api: api,
71-
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
69+
dtlsTransport: dtls,
70+
state: SCTPTransportStateConnecting,
71+
api: api,
72+
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
73+
dataChannelIDsUsed: make(map[uint16]struct{}),
7274
}
7375

7476
res.updateMessageSize()
@@ -287,6 +289,13 @@ func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
287289
r.lock.Lock()
288290
r.dataChannels = append(r.dataChannels, dc)
289291
r.dataChannelsAccepted++
292+
if dc.ID() != nil {
293+
r.dataChannelIDsUsed[*dc.ID()] = struct{}{}
294+
} else {
295+
// This cannot happen, the constructor for this datachannel in the caller
296+
// takes a pointer to the id.
297+
r.log.Errorf("accepted data channel with no ID")
298+
}
290299
handler := r.onDataChannelHandler
291300
r.lock.Unlock()
292301

@@ -393,21 +402,12 @@ func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **u
393402
r.lock.Lock()
394403
defer r.lock.Unlock()
395404

396-
// Create map of ids so we can compare without double-looping each time.
397-
idsMap := make(map[uint16]struct{}, len(r.dataChannels))
398-
for _, dc := range r.dataChannels {
399-
if dc.ID() == nil {
400-
continue
401-
}
402-
403-
idsMap[*dc.ID()] = struct{}{}
404-
}
405-
406405
for ; id < max-1; id += 2 {
407-
if _, ok := idsMap[id]; ok {
406+
if _, ok := r.dataChannelIDsUsed[id]; ok {
408407
continue
409408
}
410409
*idOut = &id
410+
r.dataChannelIDsUsed[id] = struct{}{}
411411
return nil
412412
}
413413

sctptransport_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ import "testing"
1010

1111
func TestGenerateDataChannelID(t *testing.T) {
1212
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
13-
ret := &SCTPTransport{dataChannels: []*DataChannel{}}
13+
ret := &SCTPTransport{
14+
dataChannels: []*DataChannel{},
15+
dataChannelIDsUsed: make(map[uint16]struct{}),
16+
}
1417

1518
for i := range ids {
1619
id := ids[i]
1720
ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
21+
ret.dataChannelIDsUsed[id] = struct{}{}
1822
}
1923

2024
return ret
@@ -46,5 +50,8 @@ func TestGenerateDataChannelID(t *testing.T) {
4650
if *idPtr != testCase.result {
4751
t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result)
4852
}
53+
if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok {
54+
t.Errorf("expected new id to be added to the map: %d", *idPtr)
55+
}
4956
}
5057
}

0 commit comments

Comments
 (0)