1import unittest
2from datetime import datetime, timedelta
3from unittest.mock import MagicMock, call, patch
4
5import requests_mock
6
7from streamlink import Streamlink
8from streamlink.plugin import PluginError
9from streamlink.plugins.twitch import Twitch, TwitchHLSStream, TwitchHLSStreamReader, TwitchHLSStreamWriter
10from tests.mixins.stream_hls import EventedHLSStreamWriter, Playlist, Segment as _Segment, Tag, TestMixinStreamHLS
11from tests.plugins import PluginCanHandleUrl
12
13
class TestPluginCanHandleUrlTwitch(PluginCanHandleUrl):
    """URL matching checks for the Twitch plugin."""

    __plugin__ = Twitch

    # Channel pages, VODs (both the /videos/ and /<channel>/video|v/ forms) and clips.
    should_match = [
        "https://www.twitch.tv/twitch",
        "https://www.twitch.tv/videos/150942279",
        "https://clips.twitch.tv/ObservantBenevolentCarabeefPhilosoraptor",
        "https://www.twitch.tv/twitch/video/292713971",
        "https://www.twitch.tv/twitch/v/292713971",
    ]

    # The bare site root names neither a channel nor a video.
    should_not_match = [
        "https://www.twitch.tv",
        "https://www.twitch.tv/",
    ]
29
30
# Reference instant for the timestamps generated by the playlist helpers in
# this module: midnight on 2000-01-01, as a naive datetime (no tzinfo).
DATETIME_BASE = datetime(year=2000, month=1, day=1)
# strftime pattern used when writing those timestamps into test playlists:
# microsecond precision followed by a literal "Z".
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
33
34
class TagDateRangeAd(Tag):
    """``EXT-X-DATERANGE`` tag used to mark ad breaks in test playlists.

    :param start: datetime written as the quoted ``START-DATE`` attribute,
        formatted with ``DATETIME_FORMAT``
    :param duration: value written unquoted as the ``DURATION`` attribute
    :param id: value of the quoted ``ID`` attribute
    :param classname: value of the quoted ``CLASS`` attribute
    :param custom: optional mapping of extra attribute names to values; every
        value is written as a quoted string and may override the attributes above
    """

    def __init__(self, start=DATETIME_BASE, duration=1, id="stitched-ad-1234", classname="twitch-stitched-ad", custom=None):
        quoted = self.val_quoted_string
        attrs = {
            "ID": quoted(id),
            "CLASS": quoted(classname),
            "START-DATE": quoted(start.strftime(DATETIME_FORMAT)),
            "DURATION": duration
        }
        if custom is not None:
            for attr_name, attr_value in custom.items():
                attrs[attr_name] = quoted(attr_value)
        super().__init__("EXT-X-DATERANGE", attrs)
46
47
class Segment(_Segment):
    """Media segment that is preceded by an ``EXT-X-PROGRAM-DATE-TIME`` tag.

    Segment ``num`` is stamped ``num`` seconds after ``DATETIME_BASE``.
    """

    def __init__(self, num, title="live", *args, **kwargs):
        super().__init__(num, title, *args, **kwargs)
        self.date = DATETIME_BASE + timedelta(seconds=num)

    def build(self, namespace):
        timestamp = self.date.strftime(DATETIME_FORMAT)
        segment_lines = super().build(namespace)
        return "#EXT-X-PROGRAM-DATE-TIME:{date}\n{body}".format(date=timestamp, body=segment_lines)
58
59
class SegmentPrefetch(Segment):
    """Segment advertised through the ``EXT-X-TWITCH-PREFETCH`` tag.

    Only the tag with the segment URL is emitted; no program-date-time line
    and none of the parent class's segment output is included.
    """

    def build(self, namespace):
        segment_url = self.url(namespace)
        return "#EXT-X-TWITCH-PREFETCH:{url}".format(url=segment_url)
63
64
class _TwitchHLSStreamWriter(EventedHLSStreamWriter, TwitchHLSStreamWriter):
    """Twitch HLS writer combined with the test suite's ``EventedHLSStreamWriter`` mixin."""
67
68
class _TwitchHLSStreamReader(TwitchHLSStreamReader):
    """Twitch HLS reader wired to the evented test writer."""

    __writer__ = _TwitchHLSStreamWriter
71
72
class _TwitchHLSStream(TwitchHLSStream):
    """Twitch HLS stream wired to the test reader (and thus the evented writer)."""

    __reader__ = _TwitchHLSStreamReader
75
76
@patch("streamlink.stream.hls.HLSStreamWorker.wait", MagicMock(return_value=True))
class TestTwitchHLSStream(TestMixinStreamHLS, unittest.TestCase):
    """Twitch HLS stream tests: ad segment filtering and low latency mode.

    ``HLSStreamWorker.wait`` is replaced by a mock returning ``True`` for every
    test; presumably this lets the worker reload the mocked playlists without
    actually sleeping — confirm against ``HLSStreamWorker``.
    """

    __stream__ = _TwitchHLSStream

    def get_session(self, options=None, disable_ads=False, low_latency=False):
        """Build a session with an HLS live edge of 4 and the given Twitch plugin options.

        :param options: session options forwarded to the mixin's ``get_session``
        :param disable_ads: value of the ``disable-ads`` Twitch plugin option
        :param low_latency: value of the ``low-latency`` Twitch plugin option
        :return: the configured session
        """
        session = super().get_session(options)
        session.set_option("hls-live-edge", 4)
        session.set_plugin_option("twitch", "disable-ads", disable_ads)
        session.set_plugin_option("twitch", "low-latency", low_latency)

        return session

    def test_hls_disable_ads_daterange_unknown(self):
        # A daterange whose ID and CLASS are unrecognized is not treated as an ad.
        daterange = TagDateRangeAd(start=DATETIME_BASE, duration=1, id="foo", classname="bar", custom=None)
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1)], end=True)
        ], disable_ads=True, low_latency=False)

        self.await_write(2)
        self.assertEqual(self.await_read(read_all=True), self.content(segments), "Doesn't filter out segments")
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")

    def test_hls_disable_ads_daterange_by_class(self):
        # The "twitch-stitched-ad" CLASS alone marks the daterange as an ad.
        daterange = TagDateRangeAd(start=DATETIME_BASE, duration=1, id="foo", classname="twitch-stitched-ad", custom=None)
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1)], end=True)
        ], disable_ads=True, low_latency=False)

        self.await_write(2)
        self.assertEqual(self.await_read(read_all=True), segments[1].content, "Filters out ad segments")
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")

    def test_hls_disable_ads_daterange_by_id(self):
        # The "stitched-ad-..." ID alone marks the daterange as an ad.
        daterange = TagDateRangeAd(start=DATETIME_BASE, duration=1, id="stitched-ad-1234", classname="/", custom=None)
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1)], end=True)
        ], disable_ads=True, low_latency=False)

        self.await_write(2)
        self.assertEqual(self.await_read(read_all=True), segments[1].content, "Filters out ad segments")
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")

    def test_hls_disable_ads_daterange_by_attr(self):
        # The X-TV-TWITCH-AD-URL attribute alone marks the daterange as an ad.
        daterange = TagDateRangeAd(start=DATETIME_BASE, duration=1, id="foo", classname="/", custom={"X-TV-TWITCH-AD-URL": "/"})
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1)], end=True)
        ], disable_ads=True, low_latency=False)

        self.await_write(2)
        self.assertEqual(self.await_read(read_all=True), segments[1].content, "Filters out ad segments")
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")

    @patch("streamlink.plugins.twitch.log")
    def test_hls_disable_ads_has_preroll(self, mock_log):
        # A 4s ad starting at DATETIME_BASE: segments 0-3 are dropped and the
        # pre-roll notice is logged.
        daterange = TagDateRangeAd(duration=4)
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1)]),
            Playlist(2, [daterange, Segment(2), Segment(3)]),
            Playlist(4, [Segment(4), Segment(5)], end=True)
        ], disable_ads=True, low_latency=False)

        self.await_write(6)
        self.assertEqual(
            self.await_read(read_all=True),
            self.content(segments, cond=lambda s: s.num >= 4),
            "Filters out preroll ad segments"
        )
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")
        self.assertEqual(mock_log.info.mock_calls, [
            call("Will skip ad segments"),
            call("Waiting for pre-roll ads to finish, be patient")
        ])

    @patch("streamlink.plugins.twitch.log")
    def test_hls_disable_ads_has_midstream(self, mock_log):
        # A 2s ad starting 2s in: segments 2 and 3 are dropped, with no pre-roll notice.
        daterange = TagDateRangeAd(start=DATETIME_BASE + timedelta(seconds=2), duration=2)
        _, segments = self.subject([
            Playlist(0, [Segment(0), Segment(1)]),
            Playlist(2, [daterange, Segment(2), Segment(3)]),
            Playlist(4, [Segment(4), Segment(5)], end=True)
        ], disable_ads=True, low_latency=False)

        self.await_write(6)
        self.assertEqual(
            self.await_read(read_all=True),
            self.content(segments, cond=lambda s: s.num != 2 and s.num != 3),
            "Filters out mid-stream ad segments"
        )
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")
        self.assertEqual(mock_log.info.mock_calls, [
            call("Will skip ad segments")
        ])

    @patch("streamlink.plugins.twitch.log")
    def test_hls_no_disable_ads_has_preroll(self, mock_log):
        # With disable-ads off, ad segments are passed through untouched.
        daterange = TagDateRangeAd(duration=2)
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1)]),
            Playlist(2, [Segment(2), Segment(3)], end=True)
        ], disable_ads=False, low_latency=False)

        self.await_write(4)
        self.assertEqual(
            self.await_read(read_all=True),
            self.content(segments),
            "Doesn't filter out segments"
        )
        self.assertTrue(all(self.called(s) for s in segments.values()), "Downloads all segments")
        self.assertEqual(mock_log.info.mock_calls, [], "Doesn't log anything")

    @patch("streamlink.plugins.twitch.log")
    def test_hls_low_latency_has_prefetch(self, mock_log):
        # Low latency lowers the live edge to 2, enables hls-segment-stream-data
        # and makes prefetch segments downloadable.
        _, segments = self.subject([
            Playlist(0, [Segment(0), Segment(1), Segment(2), Segment(3), SegmentPrefetch(4), SegmentPrefetch(5)]),
            Playlist(4, [Segment(4), Segment(5), Segment(6), Segment(7), SegmentPrefetch(8), SegmentPrefetch(9)], end=True)
        ], disable_ads=False, low_latency=True)

        self.assertEqual(2, self.session.options.get("hls-live-edge"))
        self.assertEqual(True, self.session.options.get("hls-segment-stream-data"))

        self.await_write(6)
        self.assertEqual(
            self.await_read(read_all=True),
            self.content(segments, cond=lambda s: s.num >= 4),
            "Skips first four segments due to reduced live-edge"
        )
        self.assertFalse(any(self.called(s) for s in segments.values() if s.num < 4), "Doesn't download old segments")
        self.assertTrue(all(self.called(s) for s in segments.values() if s.num >= 4), "Downloads all remaining segments")
        self.assertEqual(mock_log.info.mock_calls, [
            call("Low latency streaming (HLS live edge: 2)")
        ])

    @patch("streamlink.plugins.twitch.log")
    def test_hls_no_low_latency_has_prefetch(self, mock_log):
        # Without low latency the live edge stays at 4 and prefetch-only
        # segments (8 and 9) are never requested.
        _, segments = self.subject([
            Playlist(0, [Segment(0), Segment(1), Segment(2), Segment(3), SegmentPrefetch(4), SegmentPrefetch(5)]),
            Playlist(4, [Segment(4), Segment(5), Segment(6), Segment(7), SegmentPrefetch(8), SegmentPrefetch(9)], end=True)
        ], disable_ads=False, low_latency=False)

        self.assertEqual(4, self.session.options.get("hls-live-edge"))
        self.assertEqual(False, self.session.options.get("hls-segment-stream-data"))

        self.await_write(8)
        self.assertEqual(
            self.await_read(read_all=True),
            self.content(segments, cond=lambda s: s.num < 8),
            "Ignores prefetch segments"
        )
        self.assertTrue(all(self.called(s) for s in segments.values() if s.num <= 7), "Ignores prefetch segments")
        self.assertFalse(any(self.called(s) for s in segments.values() if s.num > 7), "Ignores prefetch segments")
        self.assertEqual(mock_log.info.mock_calls, [], "Doesn't log anything")

    @patch("streamlink.plugins.twitch.log")
    def test_hls_low_latency_no_prefetch(self, mock_log):
        # Low latency requested, but the playlists carry no prefetch segments.
        self.subject([
            Playlist(0, [Segment(0), Segment(1), Segment(2), Segment(3)]),
            Playlist(4, [Segment(4), Segment(5), Segment(6), Segment(7)], end=True)
        ], disable_ads=False, low_latency=True)

        self.assertTrue(self.session.get_plugin_option("twitch", "low-latency"))
        self.assertFalse(self.session.get_plugin_option("twitch", "disable-ads"))

        self.await_write(6)
        self.await_read(read_all=True)
        self.assertEqual(mock_log.info.mock_calls, [
            call("Low latency streaming (HLS live edge: 2)"),
            call("This is not a low latency stream")
        ])

    @patch("streamlink.plugins.twitch.log")
    def test_hls_low_latency_has_prefetch_has_preroll(self, mock_log):
        # Ads are not filtered here; only the reduced live edge skips segments 0 and 1.
        daterange = TagDateRangeAd(duration=4)
        _, segments = self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1), Segment(2), Segment(3)]),
            Playlist(4, [Segment(4), Segment(5), Segment(6), Segment(7), SegmentPrefetch(8), SegmentPrefetch(9)], end=True)
        ], disable_ads=False, low_latency=True)

        self.await_write(8)
        self.assertEqual(
            self.await_read(read_all=True),
            self.content(segments, cond=lambda s: s.num > 1),
            "Skips first two segments due to reduced live-edge"
        )
        self.assertFalse(any(self.called(s) for s in segments.values() if s.num < 2), "Skips first two preroll segments")
        self.assertTrue(all(self.called(s) for s in segments.values() if s.num >= 2), "Downloads all remaining segments")
        self.assertEqual(mock_log.info.mock_calls, [
            call("Low latency streaming (HLS live edge: 2)")
        ])

    @patch("streamlink.plugins.twitch.log")
    def test_hls_low_latency_has_prefetch_disable_ads_has_preroll(self, mock_log):
        # Only the order of the info log messages is checked here.
        daterange = TagDateRangeAd(duration=4)
        self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1), Segment(2), Segment(3)]),
            Playlist(4, [Segment(4), Segment(5), Segment(6), Segment(7), SegmentPrefetch(8), SegmentPrefetch(9)], end=True)
        ], disable_ads=True, low_latency=True)

        self.await_write(8)
        self.await_read(read_all=True)
        self.assertEqual(mock_log.info.mock_calls, [
            call("Will skip ad segments"),
            call("Low latency streaming (HLS live edge: 2)"),
            call("Waiting for pre-roll ads to finish, be patient")
        ])

    @patch("streamlink.plugins.twitch.log")
    def test_hls_low_latency_no_prefetch_disable_ads_has_preroll(self, mock_log):
        # Same as above, but without prefetch segments the non-low-latency notice follows.
        daterange = TagDateRangeAd(duration=4)
        self.subject([
            Playlist(0, [daterange, Segment(0), Segment(1), Segment(2), Segment(3)]),
            Playlist(4, [Segment(4), Segment(5), Segment(6), Segment(7)], end=True)
        ], disable_ads=True, low_latency=True)

        self.await_write(6)
        self.await_read(read_all=True)
        self.assertEqual(mock_log.info.mock_calls, [
            call("Will skip ad segments"),
            call("Low latency streaming (HLS live edge: 2)"),
            call("Waiting for pre-roll ads to finish, be patient"),
            call("This is not a low latency stream")
        ])
298
299
class TestTwitchMetadata(unittest.TestCase):
    """Author, title and category metadata for channel and video URLs."""

    def setUp(self):
        self.mock = requests_mock.Mocker()
        self.mock.start()

    def tearDown(self):
        self.mock.stop()

    def subject(self, url):
        """Create a plugin for ``url`` on a fresh session.

        :return: ``(author, title, category)`` as reported by the plugin
        """
        session = Streamlink()
        Twitch.bind(session, "tests.plugins.test_twitch")
        plugin = Twitch(url)
        author = plugin.get_author()
        title = plugin.get_title()
        category = plugin.get_category()
        return author, title, category

    def subject_channel(self, data=True, failure=False):
        """Mock the API responses for channel ``foo`` and fetch its metadata.

        :param data: when false, the streams endpoint reports ``"stream": None``
        :param failure: when true, the streams endpoint answers with HTTP 404
        """
        if data:
            stream = {
                "channel": {
                    "display_name": "channel name",
                    "status": "channel status",
                    "game": "channel game"
                }
            }
        else:
            stream = None

        self.mock.get(
            "https://api.twitch.tv/kraken/users?login=foo",
            json={"users": [{"_id": 1234}]}
        )
        self.mock.get(
            "https://api.twitch.tv/kraken/streams/1234",
            status_code=404 if failure else 200,
            json={"stream": stream}
        )
        return self.subject("https://twitch.tv/foo")

    def subject_video(self, data=True, failure=False):
        """Mock the API response for video ``1337`` and fetch its metadata.

        :param data: when false, the videos endpoint returns an empty JSON object
        :param failure: when true, the videos endpoint answers with HTTP 404
        """
        if data:
            video = {
                "title": "video title",
                "game": "video game",
                "channel": {
                    "display_name": "channel name"
                }
            }
        else:
            video = {}

        self.mock.get(
            "https://api.twitch.tv/kraken/videos/1337",
            status_code=404 if failure else 200,
            json=video
        )
        return self.subject("https://twitch.tv/videos/1337")

    def test_metadata_channel_exists(self):
        self.assertEqual(self.subject_channel(), ("channel name", "channel status", "channel game"))

    def test_metadata_channel_missing(self):
        self.assertEqual(self.subject_channel(data=False), (None, None, None))

    def test_metadata_channel_invalid(self):
        self.assertRaises(PluginError, self.subject_channel, failure=True)

    def test_metadata_video_exists(self):
        self.assertEqual(self.subject_video(), ("channel name", "video title", "video game"))

    def test_metadata_video_missing(self):
        self.assertEqual(self.subject_video(data=False), (None, None, None))

    def test_metadata_video_invalid(self):
        self.assertRaises(PluginError, self.subject_video, failure=True)
373
374
# The patched plugin logger is passed to every test method as ``mock_log``.
@patch("streamlink.plugins.twitch.log")
class TestTwitchHosting(unittest.TestCase):
    """Following hosted-channel chains via ``Twitch._switch_to_hosted_channel``."""

    def subject(self, channel, hosts=None, disable=False):
        """Run the hosting check for ``channel`` against mocked API responses.

        :param channel: channel name used in the plugin URL
        :param hosts: ``(host_id, target_id, target_login, target_display_name)``
            tuples, each turned into one ``/hosts`` response of a requests_mock
            response list; ``None`` mocks ``/hosts`` with an empty JSON object
        :param disable: value of the ``disable-hosting`` plugin option
        :return: ``(result, channel, channel_id, author)`` where ``result`` is the
            return value of ``_switch_to_hosted_channel()`` and the others are the
            plugin's state afterwards
        """
        with requests_mock.Mocker() as mock:
            mock.get(
                "https://api.twitch.tv/kraken/users?login=foo",
                json={"users": [{"_id": 1}]}
            )
            if hosts is not None:
                responses = []
                for host_id, target_id, target_login, target_display_name in hosts:
                    host_entry = {
                        "host_id": host_id,
                        "target_id": target_id,
                        "target_login": target_login,
                        "target_display_name": target_display_name
                    }
                    responses.append({"json": {"hosts": [host_entry]}})
                mock.get("https://tmi.twitch.tv/hosts", responses)
            else:
                mock.get("https://tmi.twitch.tv/hosts", json={})

            session = Streamlink()
            Twitch.bind(session, "tests.plugins.test_twitch")
            plugin = Twitch("https://twitch.tv/{name}".format(name=channel))
            plugin.options.set("disable-hosting", disable)

            # The call must happen while the mocker is still active.
            stopped = plugin._switch_to_hosted_channel()
            return stopped, plugin.channel, plugin._channel_id, plugin.author

    def test_hosting_invalid_host_data(self, mock_log):
        stopped, channel, channel_id, author = self.subject("foo")
        self.assertFalse(stopped, "Doesn't stop HLS resolve procedure")
        self.assertEqual(channel, "foo", "Doesn't switch channel")
        self.assertEqual(channel_id, 1, "Doesn't switch channel id")
        self.assertEqual(author, None, "Doesn't override author metadata")
        self.assertEqual(mock_log.info.mock_calls, [], "Doesn't log anything to info")
        self.assertEqual(mock_log.error.mock_calls, [], "Doesn't log anything to error")

    def test_hosting_no_host_data(self, mock_log):
        stopped, channel, channel_id, author = self.subject("foo", [(1, None, None, None)])
        self.assertFalse(stopped, "Doesn't stop HLS resolve procedure")
        self.assertEqual(channel, "foo", "Doesn't switch channel")
        self.assertEqual(channel_id, 1, "Doesn't switch channel id")
        self.assertEqual(author, None, "Doesn't override author metadata")
        self.assertEqual(mock_log.info.mock_calls, [], "Doesn't log anything to info")
        self.assertEqual(mock_log.error.mock_calls, [], "Doesn't log anything to error")

    def test_hosting_host_single(self, mock_log):
        stopped, channel, channel_id, author = self.subject("foo", [(1, 2, "bar", "Bar"), (2, None, None, None)])
        self.assertFalse(stopped, "Doesn't stop HLS resolve procedure")
        self.assertEqual(channel, "bar", "Switches channel")
        self.assertEqual(channel_id, 2, "Switches channel id")
        self.assertEqual(author, "Bar", "Overrides author metadata")
        expected_info = [
            call("foo is hosting bar"),
            call("switching to bar")
        ]
        self.assertEqual(mock_log.info.mock_calls, expected_info)
        self.assertEqual(mock_log.error.mock_calls, [], "Doesn't log anything to error")

    def test_hosting_host_single_disable(self, mock_log):
        stopped, channel, channel_id, author = self.subject("foo", [(1, 2, "bar", "Bar")], disable=True)
        self.assertTrue(stopped, "Stops HLS resolve procedure")
        self.assertEqual(channel, "foo", "Doesn't switch channel")
        self.assertEqual(channel_id, 1, "Doesn't switch channel id")
        self.assertEqual(author, None, "Doesn't override author metadata")
        expected_info = [
            call("foo is hosting bar"),
            call("hosting was disabled by command line option")
        ]
        self.assertEqual(mock_log.info.mock_calls, expected_info)
        self.assertEqual(mock_log.error.mock_calls, [], "Doesn't log anything to error")

    def test_hosting_host_multiple(self, mock_log):
        host_chain = [
            (1, 2, "bar", "Bar"),
            (2, 3, "baz", "Baz"),
            (3, 4, "qux", "Qux"),
            (4, None, None, None)
        ]
        stopped, channel, channel_id, author = self.subject("foo", host_chain)
        self.assertFalse(stopped, "Doesn't stop HLS resolve procedure")
        self.assertEqual(channel, "qux", "Switches channel")
        self.assertEqual(channel_id, 4, "Switches channel id")
        self.assertEqual(author, "Qux", "Overrides author metadata")
        expected_info = [
            call("foo is hosting bar"),
            call("switching to bar"),
            call("bar is hosting baz"),
            call("switching to baz"),
            call("baz is hosting qux"),
            call("switching to qux")
        ]
        self.assertEqual(mock_log.info.mock_calls, expected_info)
        self.assertEqual(mock_log.error.mock_calls, [], "Doesn't log anything to error")

    def test_hosting_host_multiple_loop(self, mock_log):
        # foo -> bar -> baz -> foo: the loop is detected before switching back to foo.
        host_chain = [
            (1, 2, "bar", "Bar"),
            (2, 3, "baz", "Baz"),
            (3, 1, "foo", "Foo")
        ]
        stopped, channel, channel_id, author = self.subject("foo", host_chain)
        self.assertTrue(stopped, "Stops HLS resolve procedure")
        self.assertEqual(channel, "baz", "Has switched channel")
        self.assertEqual(channel_id, 3, "Has switched channel id")
        self.assertEqual(author, "Baz", "Has overridden author metadata")
        expected_info = [
            call("foo is hosting bar"),
            call("switching to bar"),
            call("bar is hosting baz"),
            call("switching to baz"),
            call("baz is hosting foo")
        ]
        self.assertEqual(mock_log.info.mock_calls, expected_info)
        expected_error = [
            call("A loop of hosted channels has been detected, cannot find a playable stream. (foo -> bar -> baz -> foo)")
        ]
        self.assertEqual(mock_log.error.mock_calls, expected_error)
489
490
# The patched plugin logger is passed to every test method as ``mock_log``.
@patch("streamlink.plugins.twitch.log")
class TestTwitchReruns(unittest.TestCase):
    """Skipping rerun broadcasts via the ``disable-reruns`` plugin option."""

    log_call = call("Reruns were disabled by command line option")

    def subject(self, **params):
        """Run ``_check_for_rerun()`` for channel ``foo`` with mocked stream metadata.

        Recognized keyword arguments (any others are ignored):

        - ``offline`` (default ``False``): the metadata lookup returns ``None``
        - ``stream_type`` (default ``"live"``): the metadata ``type`` value
        - ``disable`` (default ``True``): the ``disable-reruns`` plugin option
        """
        if params.pop("offline", False):
            metadata = None
        else:
            metadata = {"type": params.pop("stream_type", "live")}
        disable_reruns = params.pop("disable", True)

        with patch("streamlink.plugins.twitch.TwitchAPI.stream_metadata") as mock_metadata:
            mock_metadata.return_value = metadata
            session = Streamlink()
            Twitch.bind(session, "tests.plugins.test_twitch")
            plugin = Twitch("https://www.twitch.tv/foo")
            plugin.options.set("disable-reruns", disable_reruns)

            return plugin._check_for_rerun()

    def test_disable_reruns_live(self, mock_log):
        is_skipped = self.subject()
        self.assertFalse(is_skipped)
        self.assertNotIn(self.log_call, mock_log.info.call_args_list)

    def test_disable_reruns_not_live(self, mock_log):
        is_skipped = self.subject(stream_type="rerun")
        self.assertTrue(is_skipped)
        self.assertIn(self.log_call, mock_log.info.call_args_list)

    def test_disable_reruns_offline(self, mock_log):
        is_skipped = self.subject(offline=True)
        self.assertFalse(is_skipped)
        self.assertNotIn(self.log_call, mock_log.info.call_args_list)

    def test_enable_reruns(self, mock_log):
        is_skipped = self.subject(stream_type="rerun", disable=False)
        self.assertFalse(is_skipped)
        self.assertNotIn(self.log_call, mock_log.info.call_args_list)
520