Skip to content

Commit c8fa88f

Browse files
committed
Make pc.Close wait on spawned goroutines to close
1 parent 68f19e2 commit c8fa88f

File tree

3 files changed

+148
-3
lines changed

3 files changed

+148
-3
lines changed

datachannel.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type DataChannel struct {
4040
readyState atomic.Value // DataChannelState
4141
bufferedAmountLowThreshold uint64
4242
detachCalled bool
43+
readLoopActive chan struct{}
4344

4445
// The binaryType represents attribute MUST, on getting, return the value to
4546
// which it was last set. On setting, if the new value is either the string
@@ -327,6 +328,7 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlread
327328
defer d.mu.Unlock()
328329

329330
if !d.api.settingEngine.detach.DataChannels {
331+
d.readLoopActive = make(chan struct{})
330332
go d.readLoop()
331333
}
332334
}
@@ -356,6 +358,7 @@ var rlBufPool = sync.Pool{New: func() interface{} {
356358
}}
357359

358360
func (d *DataChannel) readLoop() {
361+
defer close(d.readLoopActive)
359362
for {
360363
buffer := rlBufPool.Get().([]byte) //nolint:forcetypeassert
361364
n, isString, err := d.dataChannel.ReadDataChannel(buffer)
@@ -438,6 +441,22 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
438441
// Close Closes the DataChannel. It may be called regardless of whether
439442
// the DataChannel object was created by this peer or the remote peer.
440443
func (d *DataChannel) Close() error {
444+
return d.close(false)
445+
}
446+
447+
// Normally, close only stops writes from happening, so waitForReadsDone=true
448+
// will wait for reads to be finished based on underlying SCTP association
449+
// closure or a SCTP reset stream from the other side. This is safe to call
450+
// with waitForReadsDone=true after tearing down a PeerConnection but not
451+
// necessarily before. For example, if you used a vnet and dropped all packets
452+
// right before closing the DataChannel, you'd need never see a reset stream.
453+
func (d *DataChannel) close(waitForReadsDone bool) error {
454+
if waitForReadsDone && d.readLoopActive != nil {
455+
defer func() {
456+
<-d.readLoopActive
457+
}()
458+
}
459+
441460
d.mu.Lock()
442461
haveSctpTransport := d.dataChannel != nil
443462
d.mu.Unlock()

peerconnection.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type PeerConnection struct {
5656
idpLoginURL *string
5757

5858
isClosed *atomicBool
59+
isClosedDone chan struct{}
5960
isNegotiationNeeded *atomicBool
6061
updateNegotiationNeededFlagOnEmptyChain *atomicBool
6162

@@ -127,6 +128,7 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
127128
ICECandidatePoolSize: 0,
128129
},
129130
isClosed: &atomicBool{},
131+
isClosedDone: make(chan struct{}),
130132
isNegotiationNeeded: &atomicBool{},
131133
updateNegotiationNeededFlagOnEmptyChain: &atomicBool{},
132134
lastOffer: "",
@@ -2034,14 +2036,31 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes
20342036
return pc.dtlsTransport.WriteRTCP(pkts)
20352037
}
20362038

2037-
// Close ends the PeerConnection
2039+
// Close ends the PeerConnection.
2040+
// It will make a best effort to wait for all underlying goroutines it spawned to finish,
2041+
// except for cases that would cause deadlocks with itself.
20382042
func (pc *PeerConnection) Close() error {
20392043
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1)
20402044
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2)
20412045
if pc.isClosed.swap(true) {
2046+
// someone else got here first but may still be closing (e.g. via DTLS close_notify)
2047+
<-pc.isClosedDone
20422048
return nil
20432049
}
2050+
defer close(pc.isClosedDone)
20442051

2052+
// Try closing everything and collect the errors
2053+
// Shutdown strategy:
2054+
// 1. Close all data channels.
2055+
// 2. All Conn close by closing their underlying Conn.
2056+
// 3. A Mux stops this chain. It won't close the underlying
2057+
// Conn if one of the endpoints is closed down. To
2058+
// continue the chain the Mux has to be closed.
2059+
pc.sctpTransport.lock.Lock()
2060+
closeErrs := make([]error, 0, 4+len(pc.sctpTransport.dataChannels))
2061+
pc.sctpTransport.lock.Unlock()
2062+
2063+
// canon steps
20452064
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3)
20462065
pc.signalingState.Set(SignalingStateClosed)
20472066

@@ -2051,7 +2070,6 @@ func (pc *PeerConnection) Close() error {
20512070
// 2. A Mux stops this chain. It won't close the underlying
20522071
// Conn if one of the endpoints is closed down. To
20532072
// continue the chain the Mux has to be closed.
2054-
closeErrs := make([]error, 4)
20552073

20562074
closeErrs = append(closeErrs, pc.api.interceptor.Close())
20572075

@@ -2078,7 +2096,6 @@ func (pc *PeerConnection) Close() error {
20782096

20792097
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #7)
20802098
closeErrs = append(closeErrs, pc.dtlsTransport.Stop())
2081-
20822099
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10)
20832100
if pc.iceTransport != nil {
20842101
closeErrs = append(closeErrs, pc.iceTransport.Stop())
@@ -2087,6 +2104,13 @@ func (pc *PeerConnection) Close() error {
20872104
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11)
20882105
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
20892106

2107+
// non-canon steps
2108+
pc.sctpTransport.lock.Lock()
2109+
for _, d := range pc.sctpTransport.dataChannels {
2110+
closeErrs = append(closeErrs, d.close(true))
2111+
}
2112+
pc.sctpTransport.lock.Unlock()
2113+
20902114
return util.FlattenErrs(closeErrs)
20912115
}
20922116

peerconnection_close_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
package webrtc
88

99
import (
10+
"runtime"
11+
"strings"
1012
"testing"
1113
"time"
1214

@@ -179,3 +181,103 @@ func TestPeerConnection_Close_DuringICE(t *testing.T) {
179181
t.Error("pcOffer.Close() Timeout")
180182
}
181183
}
184+
185+
func TestPeerConnection_CloseWithIncomingMessages(t *testing.T) {
186+
// Limit runtime in case of deadlocks
187+
lim := test.TimeOut(time.Second * 20)
188+
defer lim.Stop()
189+
190+
report := CheckRoutinesIntolerant(t)
191+
defer report()
192+
193+
pcOffer, pcAnswer, err := newPair()
194+
if err != nil {
195+
t.Fatal(err)
196+
}
197+
198+
var dcAnswer *DataChannel
199+
answerDataChannelOpened := make(chan struct{})
200+
pcAnswer.OnDataChannel(func(d *DataChannel) {
201+
// Make sure this is the data channel we were looking for. (Not the one
202+
// created in signalPair).
203+
if d.Label() != "data" {
204+
return
205+
}
206+
dcAnswer = d
207+
close(answerDataChannelOpened)
208+
})
209+
210+
dcOffer, err := pcOffer.CreateDataChannel("data", nil)
211+
if err != nil {
212+
t.Fatal(err)
213+
}
214+
215+
offerDataChannelOpened := make(chan struct{})
216+
dcOffer.OnOpen(func() {
217+
close(offerDataChannelOpened)
218+
})
219+
220+
err = signalPair(pcOffer, pcAnswer)
221+
if err != nil {
222+
t.Fatal(err)
223+
}
224+
225+
<-offerDataChannelOpened
226+
<-answerDataChannelOpened
227+
228+
msgNum := 0
229+
dcOffer.OnMessage(func(_ DataChannelMessage) {
230+
t.Log("msg", msgNum)
231+
msgNum++
232+
})
233+
234+
// send 50 messages, then close pcOffer, and then send another 50
235+
for i := 0; i < 100; i++ {
236+
if i == 50 {
237+
err = pcOffer.Close()
238+
if err != nil {
239+
t.Fatal(err)
240+
}
241+
}
242+
_ = dcAnswer.Send([]byte("hello!"))
243+
}
244+
245+
err = pcAnswer.Close()
246+
if err != nil {
247+
t.Fatal(err)
248+
}
249+
}
250+
251+
// CheckRoutinesIntolerant is used to check for leaked go-routines.
252+
// It differs from test.CheckRoutines in that it won't wait at all
253+
// for lingering goroutines. This is helpful for tests that need
254+
// to ensure clean closure of resources.
255+
func CheckRoutinesIntolerant(t *testing.T) func() {
256+
return func() {
257+
routines := getRoutines()
258+
if len(routines) == 0 {
259+
return
260+
}
261+
t.Fatalf("%s: \n%s", "Unexpected routines on test end", strings.Join(routines, "\n\n")) // nolint
262+
}
263+
}
264+
265+
func getRoutines() []string {
266+
buf := make([]byte, 2<<20)
267+
buf = buf[:runtime.Stack(buf, true)]
268+
return filterRoutines(strings.Split(string(buf), "\n\n"))
269+
}
270+
271+
func filterRoutines(routines []string) []string {
272+
result := []string{}
273+
for _, stack := range routines {
274+
if stack == "" || // Empty
275+
strings.Contains(stack, "testing.Main(") || // Tests
276+
strings.Contains(stack, "testing.(*T).Run(") || // Test run
277+
strings.Contains(stack, "getRoutines(") { // This routine
278+
continue
279+
}
280+
result = append(result, stack)
281+
}
282+
return result
283+
}

0 commit comments

Comments
 (0)