1// +build !js 2 3package webrtc 4 5import ( 6 "io" 7 "testing" 8 "time" 9 10 "github.com/pion/transport/test" 11 "github.com/pion/webrtc/v3/internal/util" 12 "github.com/stretchr/testify/assert" 13) 14 15func TestDataChannel_ORTCE2E(t *testing.T) { 16 // Limit runtime in case of deadlocks 17 lim := test.TimeOut(time.Second * 20) 18 defer lim.Stop() 19 20 report := test.CheckRoutines(t) 21 defer report() 22 23 stackA, stackB, err := newORTCPair() 24 if err != nil { 25 t.Fatal(err) 26 } 27 28 awaitSetup := make(chan struct{}) 29 awaitString := make(chan struct{}) 30 awaitBinary := make(chan struct{}) 31 stackB.sctp.OnDataChannel(func(d *DataChannel) { 32 close(awaitSetup) 33 34 d.OnMessage(func(msg DataChannelMessage) { 35 if msg.IsString { 36 close(awaitString) 37 } else { 38 close(awaitBinary) 39 } 40 }) 41 }) 42 43 err = signalORTCPair(stackA, stackB) 44 if err != nil { 45 t.Fatal(err) 46 } 47 48 var id uint16 = 1 49 dcParams := &DataChannelParameters{ 50 Label: "Foo", 51 ID: &id, 52 } 53 channelA, err := stackA.api.NewDataChannel(stackA.sctp, dcParams) 54 if err != nil { 55 t.Fatal(err) 56 } 57 58 <-awaitSetup 59 60 err = channelA.SendText("ABC") 61 if err != nil { 62 t.Fatal(err) 63 } 64 err = channelA.Send([]byte("ABC")) 65 if err != nil { 66 t.Fatal(err) 67 } 68 <-awaitString 69 <-awaitBinary 70 71 err = stackA.close() 72 if err != nil { 73 t.Fatal(err) 74 } 75 76 err = stackB.close() 77 if err != nil { 78 t.Fatal(err) 79 } 80 81 // attempt to send when channel is closed 82 err = channelA.Send([]byte("ABC")) 83 assert.Error(t, err) 84 assert.Equal(t, io.ErrClosedPipe, err) 85 86 err = channelA.SendText("test") 87 assert.Error(t, err) 88 assert.Equal(t, io.ErrClosedPipe, err) 89 90 err = channelA.ensureOpen() 91 assert.Error(t, err) 92 assert.Equal(t, io.ErrClosedPipe, err) 93} 94 95type testORTCStack struct { 96 api *API 97 gatherer *ICEGatherer 98 ice *ICETransport 99 dtls *DTLSTransport 100 sctp *SCTPTransport 101} 102 103func (s *testORTCStack) setSignal(sig *testORTCSignal, isOffer bool) error { 104 iceRole := ICERoleControlled 105 if isOffer { 106 iceRole = ICERoleControlling 107 } 108 109 err := s.ice.SetRemoteCandidates(sig.ICECandidates) 110 if err != nil { 111 return err 112 } 113 114 // Start the ICE transport 115 err = s.ice.Start(nil, sig.ICEParameters, &iceRole) 116 if err != nil { 117 return err 118 } 119 120 // Start the DTLS transport 121 err = s.dtls.Start(sig.DTLSParameters) 122 if err != nil { 123 return err 124 } 125 126 // Start the SCTP transport 127 err = s.sctp.Start(sig.SCTPCapabilities) 128 if err != nil { 129 return err 130 } 131 132 return nil 133} 134 135func (s *testORTCStack) getSignal() (*testORTCSignal, error) { 136 gatherFinished := make(chan struct{}) 137 s.gatherer.OnLocalCandidate(func(i *ICECandidate) { 138 if i == nil { 139 close(gatherFinished) 140 } 141 }) 142 143 if err := s.gatherer.Gather(); err != nil { 144 return nil, err 145 } 146 147 <-gatherFinished 148 iceCandidates, err := s.gatherer.GetLocalCandidates() 149 if err != nil { 150 return nil, err 151 } 152 153 iceParams, err := s.gatherer.GetLocalParameters() 154 if err != nil { 155 return nil, err 156 } 157 158 dtlsParams, err := s.dtls.GetLocalParameters() 159 if err != nil { 160 return nil, err 161 } 162 163 sctpCapabilities := s.sctp.GetCapabilities() 164 165 return &testORTCSignal{ 166 ICECandidates: iceCandidates, 167 ICEParameters: iceParams, 168 DTLSParameters: dtlsParams, 169 SCTPCapabilities: sctpCapabilities, 170 }, nil 171} 172 173func (s *testORTCStack) close() error { 174 var closeErrs []error 175 176 if err := s.sctp.Stop(); err != nil { 177 closeErrs = append(closeErrs, err) 178 } 179 180 if err := s.ice.Stop(); err != nil { 181 closeErrs = append(closeErrs, err) 182 } 183 184 return util.FlattenErrs(closeErrs) 185} 186 187type testORTCSignal struct { 188 ICECandidates []ICECandidate `json:"iceCandidates"` 189 ICEParameters ICEParameters `json:"iceParameters"` 190 DTLSParameters DTLSParameters `json:"dtlsParameters"` 191 SCTPCapabilities SCTPCapabilities `json:"sctpCapabilities"` 192} 193 194func newORTCPair() (stackA *testORTCStack, stackB *testORTCStack, err error) { 195 sa, err := newORTCStack() 196 if err != nil { 197 return nil, nil, err 198 } 199 200 sb, err := newORTCStack() 201 if err != nil { 202 return nil, nil, err 203 } 204 205 return sa, sb, nil 206} 207 208func newORTCStack() (*testORTCStack, error) { 209 // Create an API object 210 api := NewAPI() 211 212 // Create the ICE gatherer 213 gatherer, err := api.NewICEGatherer(ICEGatherOptions{}) 214 if err != nil { 215 return nil, err 216 } 217 218 // Construct the ICE transport 219 ice := api.NewICETransport(gatherer) 220 221 // Construct the DTLS transport 222 dtls, err := api.NewDTLSTransport(ice, nil) 223 if err != nil { 224 return nil, err 225 } 226 227 // Construct the SCTP transport 228 sctp := api.NewSCTPTransport(dtls) 229 230 return &testORTCStack{ 231 api: api, 232 gatherer: gatherer, 233 ice: ice, 234 dtls: dtls, 235 sctp: sctp, 236 }, nil 237} 238 239func signalORTCPair(stackA *testORTCStack, stackB *testORTCStack) error { 240 sigA, err := stackA.getSignal() 241 if err != nil { 242 return err 243 } 244 sigB, err := stackB.getSignal() 245 if err != nil { 246 return err 247 } 248 249 a := make(chan error) 250 b := make(chan error) 251 252 go func() { 253 a <- stackB.setSignal(sigA, false) 254 }() 255 256 go func() { 257 b <- stackA.setSignal(sigB, true) 258 }() 259 260 errA := <-a 261 errB := <-b 262 263 closeErrs := []error{errA, errB} 264 265 return util.FlattenErrs(closeErrs) 266} 267