1// +build !js
2
3package webrtc
4
5import (
6	"io"
7	"testing"
8	"time"
9
10	"github.com/pion/transport/test"
11	"github.com/pion/webrtc/v3/internal/util"
12	"github.com/stretchr/testify/assert"
13)
14
15func TestDataChannel_ORTCE2E(t *testing.T) {
16	// Limit runtime in case of deadlocks
17	lim := test.TimeOut(time.Second * 20)
18	defer lim.Stop()
19
20	report := test.CheckRoutines(t)
21	defer report()
22
23	stackA, stackB, err := newORTCPair()
24	if err != nil {
25		t.Fatal(err)
26	}
27
28	awaitSetup := make(chan struct{})
29	awaitString := make(chan struct{})
30	awaitBinary := make(chan struct{})
31	stackB.sctp.OnDataChannel(func(d *DataChannel) {
32		close(awaitSetup)
33
34		d.OnMessage(func(msg DataChannelMessage) {
35			if msg.IsString {
36				close(awaitString)
37			} else {
38				close(awaitBinary)
39			}
40		})
41	})
42
43	err = signalORTCPair(stackA, stackB)
44	if err != nil {
45		t.Fatal(err)
46	}
47
48	var id uint16 = 1
49	dcParams := &DataChannelParameters{
50		Label: "Foo",
51		ID:    &id,
52	}
53	channelA, err := stackA.api.NewDataChannel(stackA.sctp, dcParams)
54	if err != nil {
55		t.Fatal(err)
56	}
57
58	<-awaitSetup
59
60	err = channelA.SendText("ABC")
61	if err != nil {
62		t.Fatal(err)
63	}
64	err = channelA.Send([]byte("ABC"))
65	if err != nil {
66		t.Fatal(err)
67	}
68	<-awaitString
69	<-awaitBinary
70
71	err = stackA.close()
72	if err != nil {
73		t.Fatal(err)
74	}
75
76	err = stackB.close()
77	if err != nil {
78		t.Fatal(err)
79	}
80
81	// attempt to send when channel is closed
82	err = channelA.Send([]byte("ABC"))
83	assert.Error(t, err)
84	assert.Equal(t, io.ErrClosedPipe, err)
85
86	err = channelA.SendText("test")
87	assert.Error(t, err)
88	assert.Equal(t, io.ErrClosedPipe, err)
89
90	err = channelA.ensureOpen()
91	assert.Error(t, err)
92	assert.Equal(t, io.ErrClosedPipe, err)
93}
94
95type testORTCStack struct {
96	api      *API
97	gatherer *ICEGatherer
98	ice      *ICETransport
99	dtls     *DTLSTransport
100	sctp     *SCTPTransport
101}
102
103func (s *testORTCStack) setSignal(sig *testORTCSignal, isOffer bool) error {
104	iceRole := ICERoleControlled
105	if isOffer {
106		iceRole = ICERoleControlling
107	}
108
109	err := s.ice.SetRemoteCandidates(sig.ICECandidates)
110	if err != nil {
111		return err
112	}
113
114	// Start the ICE transport
115	err = s.ice.Start(nil, sig.ICEParameters, &iceRole)
116	if err != nil {
117		return err
118	}
119
120	// Start the DTLS transport
121	err = s.dtls.Start(sig.DTLSParameters)
122	if err != nil {
123		return err
124	}
125
126	// Start the SCTP transport
127	err = s.sctp.Start(sig.SCTPCapabilities)
128	if err != nil {
129		return err
130	}
131
132	return nil
133}
134
135func (s *testORTCStack) getSignal() (*testORTCSignal, error) {
136	gatherFinished := make(chan struct{})
137	s.gatherer.OnLocalCandidate(func(i *ICECandidate) {
138		if i == nil {
139			close(gatherFinished)
140		}
141	})
142
143	if err := s.gatherer.Gather(); err != nil {
144		return nil, err
145	}
146
147	<-gatherFinished
148	iceCandidates, err := s.gatherer.GetLocalCandidates()
149	if err != nil {
150		return nil, err
151	}
152
153	iceParams, err := s.gatherer.GetLocalParameters()
154	if err != nil {
155		return nil, err
156	}
157
158	dtlsParams, err := s.dtls.GetLocalParameters()
159	if err != nil {
160		return nil, err
161	}
162
163	sctpCapabilities := s.sctp.GetCapabilities()
164
165	return &testORTCSignal{
166		ICECandidates:    iceCandidates,
167		ICEParameters:    iceParams,
168		DTLSParameters:   dtlsParams,
169		SCTPCapabilities: sctpCapabilities,
170	}, nil
171}
172
173func (s *testORTCStack) close() error {
174	var closeErrs []error
175
176	if err := s.sctp.Stop(); err != nil {
177		closeErrs = append(closeErrs, err)
178	}
179
180	if err := s.ice.Stop(); err != nil {
181		closeErrs = append(closeErrs, err)
182	}
183
184	return util.FlattenErrs(closeErrs)
185}
186
187type testORTCSignal struct {
188	ICECandidates    []ICECandidate   `json:"iceCandidates"`
189	ICEParameters    ICEParameters    `json:"iceParameters"`
190	DTLSParameters   DTLSParameters   `json:"dtlsParameters"`
191	SCTPCapabilities SCTPCapabilities `json:"sctpCapabilities"`
192}
193
194func newORTCPair() (stackA *testORTCStack, stackB *testORTCStack, err error) {
195	sa, err := newORTCStack()
196	if err != nil {
197		return nil, nil, err
198	}
199
200	sb, err := newORTCStack()
201	if err != nil {
202		return nil, nil, err
203	}
204
205	return sa, sb, nil
206}
207
208func newORTCStack() (*testORTCStack, error) {
209	// Create an API object
210	api := NewAPI()
211
212	// Create the ICE gatherer
213	gatherer, err := api.NewICEGatherer(ICEGatherOptions{})
214	if err != nil {
215		return nil, err
216	}
217
218	// Construct the ICE transport
219	ice := api.NewICETransport(gatherer)
220
221	// Construct the DTLS transport
222	dtls, err := api.NewDTLSTransport(ice, nil)
223	if err != nil {
224		return nil, err
225	}
226
227	// Construct the SCTP transport
228	sctp := api.NewSCTPTransport(dtls)
229
230	return &testORTCStack{
231		api:      api,
232		gatherer: gatherer,
233		ice:      ice,
234		dtls:     dtls,
235		sctp:     sctp,
236	}, nil
237}
238
239func signalORTCPair(stackA *testORTCStack, stackB *testORTCStack) error {
240	sigA, err := stackA.getSignal()
241	if err != nil {
242		return err
243	}
244	sigB, err := stackB.getSignal()
245	if err != nil {
246		return err
247	}
248
249	a := make(chan error)
250	b := make(chan error)
251
252	go func() {
253		a <- stackB.setSignal(sigA, false)
254	}()
255
256	go func() {
257		b <- stackA.setSignal(sigB, true)
258	}()
259
260	errA := <-a
261	errB := <-b
262
263	closeErrs := []error{errA, errB}
264
265	return util.FlattenErrs(closeErrs)
266}
267