Skip to content

Commit 60eea43

Browse files
committed
Close PeerConnection on DTLS CloseNotify
Resolves #1767 Resolves pion/dtls#151
1 parent ea23dec commit 60eea43

File tree

6 files changed

+90
-22
lines changed

6 files changed

+90
-22
lines changed

dtlstransport.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ type DTLSTransport struct {
4444
state DTLSTransportState
4545
srtpProtectionProfile srtp.ProtectionProfile
4646

47-
onStateChangeHandler func(DTLSTransportState)
47+
onStateChangeHandler func(DTLSTransportState)
48+
internalOnCloseHandler func()
4849

4950
conn *dtls.Conn
5051

@@ -322,6 +323,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
322323

323324
var dtlsConn *dtls.Conn
324325
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
326+
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
325327
role, dtlsConfig, err := prepareTransport()
326328
if err != nil {
327329
return err

dtlstransport_test.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717

1818
// An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed
1919
func TestInvalidFingerprintCausesFailed(t *testing.T) {
20-
lim := test.TimeOut(time.Second * 40)
20+
lim := test.TimeOut(time.Second * 5)
2121
defer lim.Stop()
2222

2323
report := test.CheckRoutines(t)
@@ -46,8 +46,8 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) {
4646
}
4747
})
4848

49-
offerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcOffer)
50-
answerConnectionHasFailed := untilConnectionState(PeerConnectionStateFailed, pcAnswer)
49+
offerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcOffer)
50+
answerConnectionHasClosed := untilConnectionState(PeerConnectionStateClosed, pcAnswer)
5151

5252
if _, err = pcOffer.CreateDataChannel("unusedDataChannel", nil); err != nil {
5353
t.Fatal(err)
@@ -89,13 +89,17 @@ func TestInvalidFingerprintCausesFailed(t *testing.T) {
8989
t.Fatal("timed out waiting to receive offer")
9090
}
9191

92-
offerConnectionHasFailed.Wait()
93-
answerConnectionHasFailed.Wait()
92+
offerConnectionHasClosed.Wait()
93+
answerConnectionHasClosed.Wait()
9494

95-
assert.Equal(t, pcOffer.SCTP().Transport().State(), DTLSTransportStateFailed)
95+
if pcOffer.SCTP().Transport().State() != DTLSTransportStateClosed && pcOffer.SCTP().Transport().State() != DTLSTransportStateFailed {
96+
t.Fail()
97+
}
9698
assert.Nil(t, pcOffer.SCTP().Transport().conn)
9799

98-
assert.Equal(t, pcAnswer.SCTP().Transport().State(), DTLSTransportStateFailed)
100+
if pcAnswer.SCTP().Transport().State() != DTLSTransportStateClosed && pcAnswer.SCTP().Transport().State() != DTLSTransportStateFailed {
101+
t.Fail()
102+
}
99103
assert.Nil(t, pcAnswer.SCTP().Transport().conn)
100104
}
101105

internal/mux/endpoint.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@ import (
1515

1616
// Endpoint implements net.Conn. It is used to read muxed packets.
1717
type Endpoint struct {
18-
mux *Mux
19-
buffer *packetio.Buffer
18+
mux *Mux
19+
buffer *packetio.Buffer
20+
onClose func()
2021
}
2122

2223
// Close unregisters the endpoint from the Mux
2324
func (e *Endpoint) Close() (err error) {
24-
err = e.close()
25-
if err != nil {
25+
if e.onClose != nil {
26+
e.onClose()
27+
}
28+
29+
if err = e.close(); err != nil {
2630
return err
2731
}
2832

@@ -76,3 +80,9 @@ func (e *Endpoint) SetReadDeadline(time.Time) error {
7680
func (e *Endpoint) SetWriteDeadline(time.Time) error {
7781
return nil
7882
}
83+
84+
// SetOnClose is a user set callback that
85+
// will be executed when `Close` is called
86+
func (e *Endpoint) SetOnClose(onClose func()) {
87+
e.onClose = onClose
88+
}

peerconnection.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,16 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
22612261
return
22622262
}
22632263

2264+
pc.dtlsTransport.internalOnCloseHandler = func() {
2265+
pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
2266+
2267+
go func() {
2268+
if pcClosErr := pc.Close(); pcClosErr != nil {
2269+
pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
2270+
}
2271+
}()
2272+
}
2273+
22642274
// Start the dtls transport
22652275
err = pc.dtlsTransport.Start(DTLSParameters{
22662276
Role: dtlsRole,

peerconnection_media_test.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/pion/rtp"
2525
"github.com/pion/sdp/v3"
2626
"github.com/pion/transport/v3/test"
27+
"github.com/pion/transport/v3/vnet"
2728
"github.com/pion/webrtc/v3/pkg/media"
2829
"github.com/stretchr/testify/assert"
2930
"github.com/stretchr/testify/require"
@@ -329,10 +330,15 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
329330
m := &MediaEngine{}
330331
assert.NoError(t, m.RegisterDefaultCodecs())
331332

332-
pcOffer, pcAnswer, err := NewAPI(WithSettingEngine(s), WithMediaEngine(m)).newPair(Configuration{})
333-
if err != nil {
334-
t.Fatal(err)
335-
}
333+
pcOffer, pcAnswer, wan := createVNetPair(t)
334+
335+
keepPackets := &atomicBool{}
336+
keepPackets.set(true)
337+
338+
// Add a filter that monitors the traffic on the router
339+
wan.AddChunkFilter(func(c vnet.Chunk) bool {
340+
return keepPackets.get()
341+
})
336342

337343
vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2")
338344
if err != nil {
@@ -360,14 +366,11 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
360366
time.Sleep(time.Second)
361367
}
362368

363-
if pcCloseErr := pcAnswer.Close(); pcCloseErr != nil {
364-
haveDisconnected <- pcCloseErr
365-
}
369+
keepPackets.set(false)
366370
}
367371
})
368372

369-
err = signalPair(pcOffer, pcAnswer)
370-
if err != nil {
373+
if err = signalPair(pcOffer, pcAnswer); err != nil {
371374
t.Fatal(err)
372375
}
373376

@@ -383,7 +386,8 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) {
383386
}
384387
}
385388

386-
assert.NoError(t, pcOffer.Close())
389+
assert.NoError(t, wan.Stop())
390+
closePairNow(t, pcOffer, pcAnswer)
387391
}
388392

389393
type undeclaredSsrcLogger struct{ unhandledSimulcastError chan struct{} }

peerconnection_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,3 +754,41 @@ func TestTransportChain(t *testing.T) {
754754

755755
closePairNow(t, offer, answer)
756756
}
757+
758+
// Assert that the PeerConnection closes via DTLS (and not ICE)
759+
func TestDTLSClose(t *testing.T) {
760+
lim := test.TimeOut(time.Second * 10)
761+
defer lim.Stop()
762+
763+
report := test.CheckRoutines(t)
764+
defer report()
765+
766+
pcOffer, pcAnswer, err := newPair()
767+
assert.NoError(t, err)
768+
769+
_, err = pcOffer.AddTransceiverFromKind(RTPCodecTypeVideo)
770+
assert.NoError(t, err)
771+
772+
peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)
773+
774+
offer, err := pcOffer.CreateOffer(nil)
775+
assert.NoError(t, err)
776+
777+
offerGatheringComplete := GatheringCompletePromise(pcOffer)
778+
assert.NoError(t, pcOffer.SetLocalDescription(offer))
779+
<-offerGatheringComplete
780+
781+
assert.NoError(t, pcAnswer.SetRemoteDescription(*pcOffer.LocalDescription()))
782+
783+
answer, err := pcAnswer.CreateAnswer(nil)
784+
assert.NoError(t, err)
785+
786+
answerGatheringComplete := GatheringCompletePromise(pcAnswer)
787+
assert.NoError(t, pcAnswer.SetLocalDescription(answer))
788+
<-answerGatheringComplete
789+
790+
assert.NoError(t, pcOffer.SetRemoteDescription(*pcAnswer.LocalDescription()))
791+
792+
peerConnectionsConnected.Wait()
793+
assert.NoError(t, pcOffer.Close())
794+
}

0 commit comments

Comments
 (0)