
Commit 7c8bfbd

edaniels authored and Sean-Der committed
Make pc.Close wait on spawned goroutines to close
1 parent f229661

File tree (3 files changed: +152, -4 lines)

  datachannel.go
  peerconnection.go
  peerconnection_close_test.go

datachannel.go

Lines changed: 19 additions & 0 deletions

@@ -40,6 +40,7 @@ type DataChannel struct {
 	readyState                 atomic.Value // DataChannelState
 	bufferedAmountLowThreshold uint64
 	detachCalled               bool
+	readLoopActive             chan struct{}

 	// The binaryType represents attribute MUST, on getting, return the value to
 	// which it was last set. On setting, if the new value is either the string
@@ -327,6 +328,7 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlread
 	defer d.mu.Unlock()

 	if !d.api.settingEngine.detach.DataChannels {
+		d.readLoopActive = make(chan struct{})
 		go d.readLoop()
 	}
 }
@@ -350,6 +352,7 @@ func (d *DataChannel) onError(err error) {
 }

 func (d *DataChannel) readLoop() {
+	defer close(d.readLoopActive)
 	buffer := make([]byte, dataChannelBufferSize)
 	for {
 		n, isString, err := d.dataChannel.ReadDataChannel(buffer)
@@ -449,6 +452,22 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
 // Close Closes the DataChannel. It may be called regardless of whether
 // the DataChannel object was created by this peer or the remote peer.
 func (d *DataChannel) Close() error {
+	return d.close(false)
+}
+
+// Normally, close only stops writes from happening, so waitForReadsDone=true
+// will wait for reads to be finished based on underlying SCTP association
+// closure or an SCTP reset stream from the other side. This is safe to call
+// with waitForReadsDone=true after tearing down a PeerConnection but not
+// necessarily before. For example, if you used a vnet and dropped all packets
+// right before closing the DataChannel, you might never see a reset stream.
+func (d *DataChannel) close(waitForReadsDone bool) error {
+	if waitForReadsDone && d.readLoopActive != nil {
+		defer func() {
+			<-d.readLoopActive
+		}()
+	}
+
 	d.mu.Lock()
 	haveSctpTransport := d.dataChannel != nil
 	d.mu.Unlock()
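The pattern above is worth seeing in isolation: the goroutine that owns the read loop closes a channel when it exits, and a closer that wants to wait simply receives from that channel. Below is a minimal, self-contained sketch of that idiom; the chanReader type and its channel-backed transport are illustrative stand-ins, not pion/webrtc API.

package main

import (
	"fmt"
	"time"
)

// chanReader mimics the DataChannel wiring in the diff: readLoopActive is
// created when the read loop starts and closed when the loop returns.
type chanReader struct {
	readLoopActive chan struct{}
	incoming       chan []byte
}

func (c *chanReader) start() {
	c.readLoopActive = make(chan struct{})
	go c.readLoop()
}

func (c *chanReader) readLoop() {
	// The close runs exactly once, on every exit path of the loop.
	defer close(c.readLoopActive)
	for msg := range c.incoming {
		fmt.Printf("read %d bytes\n", len(msg))
	}
}

// close mirrors DataChannel.close: by default it only stops writes; with
// waitForReadsDone=true it also blocks until the read loop has returned.
func (c *chanReader) close(waitForReadsDone bool) {
	if waitForReadsDone && c.readLoopActive != nil {
		defer func() { <-c.readLoopActive }()
	}
	// Closing incoming stands in for the SCTP association closing or a
	// stream reset ending ReadDataChannel with an error.
	close(c.incoming)
}

func main() {
	c := &chanReader{incoming: make(chan []byte, 4)}
	c.start()
	c.incoming <- []byte("hello")
	time.Sleep(10 * time.Millisecond) // give the loop a moment to read
	c.close(true)                     // returns only after readLoop exits
	fmt.Println("read loop finished")
}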

peerconnection.go

Lines changed: 31 additions & 4 deletions

@@ -56,6 +56,7 @@ type PeerConnection struct {
 	idpLoginURL *string

 	isClosed                                *atomicBool
+	isClosedDone                            chan struct{}
 	isNegotiationNeeded                     *atomicBool
 	updateNegotiationNeededFlagOnEmptyChain *atomicBool

@@ -116,6 +117,7 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
 			ICECandidatePoolSize: 0,
 		},
 		isClosed:            &atomicBool{},
+		isClosedDone:        make(chan struct{}),
 		isNegotiationNeeded: &atomicBool{},
 		updateNegotiationNeededFlagOnEmptyChain: &atomicBool{},
 		lastOffer: "",
@@ -2044,14 +2046,31 @@ func (pc *PeerConnection) writeRTCP(pkts []rtcp.Packet, _ interceptor.Attributes
 	return pc.dtlsTransport.WriteRTCP(pkts)
 }

-// Close ends the PeerConnection
+// Close ends the PeerConnection.
+// It will make a best effort to wait for all underlying goroutines it spawned to finish,
+// except for cases that would cause deadlocks with itself.
 func (pc *PeerConnection) Close() error {
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1)
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2)
 	if pc.isClosed.swap(true) {
+		// someone else got here first but may still be closing (e.g. via DTLS close_notify)
+		<-pc.isClosedDone
 		return nil
 	}
+	defer close(pc.isClosedDone)

+	// Try closing everything and collect the errors
+	// Shutdown strategy:
+	// 1. Close all data channels.
+	// 2. All Conn close by closing their underlying Conn.
+	// 3. A Mux stops this chain. It won't close the underlying
+	//    Conn if one of the endpoints is closed down. To
+	//    continue the chain the Mux has to be closed.
+	pc.sctpTransport.lock.Lock()
+	closeErrs := make([]error, 0, 4+len(pc.sctpTransport.dataChannels))
+	pc.sctpTransport.lock.Unlock()
+
+	// canon steps
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3)
 	pc.signalingState.Set(SignalingStateClosed)

@@ -2061,7 +2080,6 @@ func (pc *PeerConnection) Close() error {
 	// 2. A Mux stops this chain. It won't close the underlying
 	//    Conn if one of the endpoints is closed down. To
 	//    continue the chain the Mux has to be closed.
-	closeErrs := make([]error, 4)

 	closeErrs = append(closeErrs, pc.api.interceptor.Close())

@@ -2088,7 +2106,6 @@ func (pc *PeerConnection) Close() error {

 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #7)
 	closeErrs = append(closeErrs, pc.dtlsTransport.Stop())
-
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10)
 	if pc.iceTransport != nil {
 		closeErrs = append(closeErrs, pc.iceTransport.Stop())
@@ -2097,6 +2114,13 @@ func (pc *PeerConnection) Close() error {
 	// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11)
 	pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())

+	// non-canon steps
+	pc.sctpTransport.lock.Lock()
+	for _, d := range pc.sctpTransport.dataChannels {
+		closeErrs = append(closeErrs, d.close(true))
+	}
+	pc.sctpTransport.lock.Unlock()
+
 	return util.FlattenErrs(closeErrs)
 }

@@ -2268,8 +2292,11 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
 	}

 	pc.dtlsTransport.internalOnCloseHandler = func() {
-		pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
+		if pc.isClosed.get() {
+			return
+		}

+		pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
 		go func() {
 			if pcClosErr := pc.Close(); pcClosErr != nil {
 				pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
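The isClosed/isClosedDone pair is what makes concurrent Close() calls safe here: the first caller wins the swap, performs the teardown, and closes isClosedDone; every later caller blocks on that channel until teardown completes instead of returning early. A minimal sketch of the idiom, using the standard library's atomic.Bool in place of pion's internal atomicBool (names are illustrative):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// closer sketches the isClosed/isClosedDone pattern from the diff.
type closer struct {
	isClosed atomic.Bool
	done     chan struct{}
}

func newCloser() *closer {
	return &closer{done: make(chan struct{})}
}

func (c *closer) Close() error {
	if c.isClosed.Swap(true) {
		// Someone else got here first but may still be closing; wait for them.
		<-c.done
		return nil
	}
	defer close(c.done)

	time.Sleep(50 * time.Millisecond) // stands in for the real teardown work
	fmt.Println("teardown finished")
	return nil
}

func main() {
	c := newCloser()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = c.Close() // all three return only after teardown completes
		}()
	}
	wg.Wait()
}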

peerconnection_close_test.go

Lines changed: 102 additions & 0 deletions

@@ -7,6 +7,8 @@
 package webrtc

 import (
+	"runtime"
+	"strings"
 	"testing"
 	"time"

@@ -179,3 +181,103 @@ func TestPeerConnection_Close_DuringICE(t *testing.T) {
 		t.Error("pcOffer.Close() Timeout")
 	}
 }
+
+func TestPeerConnection_CloseWithIncomingMessages(t *testing.T) {
+	// Limit runtime in case of deadlocks
+	lim := test.TimeOut(time.Second * 20)
+	defer lim.Stop()
+
+	report := CheckRoutinesIntolerant(t)
+	defer report()
+
+	pcOffer, pcAnswer, err := newPair()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var dcAnswer *DataChannel
+	answerDataChannelOpened := make(chan struct{})
+	pcAnswer.OnDataChannel(func(d *DataChannel) {
+		// Make sure this is the data channel we were looking for. (Not the one
+		// created in signalPair).
+		if d.Label() != "data" {
+			return
+		}
+		dcAnswer = d
+		close(answerDataChannelOpened)
+	})
+
+	dcOffer, err := pcOffer.CreateDataChannel("data", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	offerDataChannelOpened := make(chan struct{})
+	dcOffer.OnOpen(func() {
+		close(offerDataChannelOpened)
+	})
+
+	err = signalPair(pcOffer, pcAnswer)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	<-offerDataChannelOpened
+	<-answerDataChannelOpened
+
+	msgNum := 0
+	dcOffer.OnMessage(func(_ DataChannelMessage) {
+		t.Log("msg", msgNum)
+		msgNum++
+	})
+
+	// send 50 messages, then close pcOffer, and then send another 50
+	for i := 0; i < 100; i++ {
+		if i == 50 {
+			err = pcOffer.Close()
+			if err != nil {
+				t.Fatal(err)
+			}
+		}
+		_ = dcAnswer.Send([]byte("hello!"))
+	}
+
+	err = pcAnswer.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// CheckRoutinesIntolerant is used to check for leaked go-routines.
+// It differs from test.CheckRoutines in that it won't wait at all
+// for lingering goroutines. This is helpful for tests that need
+// to ensure clean closure of resources.
+func CheckRoutinesIntolerant(t *testing.T) func() {
+	return func() {
+		routines := getRoutines()
+		if len(routines) == 0 {
+			return
+		}
+		t.Fatalf("%s: \n%s", "Unexpected routines on test end", strings.Join(routines, "\n\n")) // nolint
+	}
+}
+
+func getRoutines() []string {
+	buf := make([]byte, 2<<20)
+	buf = buf[:runtime.Stack(buf, true)]
+	return filterRoutines(strings.Split(string(buf), "\n\n"))
+}
+
+func filterRoutines(routines []string) []string {
+	result := []string{}
+	for _, stack := range routines {
+		if stack == "" || // Empty
+			strings.Contains(stack, "testing.Main(") || // Tests
+			strings.Contains(stack, "testing.(*T).Run(") || // Test run
+			strings.Contains(stack, "getRoutines(") { // This routine
+			continue
+		}
+		result = append(result, stack)
+	}
+	return result
+}
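For callers, the observable effect of this commit is that pc.Close() now returns only after the data-channel read loops have wound down, so a goroutine-leak check can run immediately afterwards, as the test above does with CheckRoutinesIntolerant. A rough usage sketch follows; the v3 module path and the coarse runtime.NumGoroutine comparison are assumptions for illustration, not part of this change:

package main

import (
	"fmt"
	"runtime"

	"github.com/pion/webrtc/v3"
)

func main() {
	pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
	if err != nil {
		panic(err)
	}

	// ... use the PeerConnection ...

	// With this commit, Close blocks until spawned goroutines are done
	// (except where waiting would deadlock with Close itself), so a leak
	// check right here no longer needs a grace period.
	if err := pc.Close(); err != nil {
		panic(err)
	}
	fmt.Println("goroutines after close:", runtime.NumGoroutine())
}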
