1import os
2import unittest
3from unittest.mock import Mock, patch
4
5import pytest
6import requests_mock
7from Crypto.Cipher import AES
8
9from streamlink.session import Streamlink
10from streamlink.stream import hls
11from tests.mixins.stream_hls import Playlist, Segment, Tag, TestMixinStreamHLS
12from tests.resources import text
13
14
def pkcs7_encode(data, keySize):
    """Return *data* with PKCS#7 padding appended, up to a multiple of *keySize* bytes.

    Between 1 and keySize bytes are always added, each holding the pad length,
    so input that is already aligned gains one whole extra block of padding.
    keySize must be at most 255 for the pad byte to be representable.
    """
    pad_length = keySize - len(data) % keySize
    padding = bytes([pad_length]) * pad_length
    # bytes.join keeps the result as `bytes` even when data is a bytearray
    return b"".join((data, padding))
18
19
def encrypt(data, key, iv):
    """Encrypt *data* with AES in CBC mode using *key* and *iv*, after PKCS#7 padding.

    Padding is always to the AES block size (16 bytes), not to the key length:
    CBC mode requires block-aligned input, and for AES-192/AES-256 keys
    (24/32 bytes) the key length is not the block size, so padding to
    len(key) would produce misaligned input.
    """
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return cipher.encrypt(pkcs7_encode(data, AES.block_size))
24
25
class TagKey(Tag):
    """EXT-X-KEY tag used by the HLS test playlists.

    `uri` may be a format string containing `{namespace}`. When it is falsy,
    the base Tag URL is used instead (presumably derived from `path` — see
    Tag.url). Passing `uri=False` omits the URI attribute altogether.
    """

    path = "encryption.key"

    def __init__(self, method="NONE", uri=None, iv=None, keyformat=None, keyformatversions=None):
        attributes = {"METHOD": method}

        # The URI is rendered lazily, because the namespace is only known when the playlist is built
        if uri is not False:  # pragma: no branch
            attributes["URI"] = lambda tag, namespace: tag.val_quoted_string(tag.url(namespace))

        if iv is not None:  # pragma: no branch
            attributes["IV"] = self.val_hex(iv)

        if keyformat is not None:  # pragma: no branch
            attributes["KEYFORMAT"] = self.val_quoted_string(keyformat)

        if keyformatversions is not None:  # pragma: no branch
            attributes["KEYFORMATVERSIONS"] = self.val_quoted_string(keyformatversions)

        super().__init__("EXT-X-KEY", attributes)
        self.uri = uri

    def url(self, namespace):
        """Return the key URL: the custom `uri` template if set, else the base Tag URL."""
        if not self.uri:
            return super().url(namespace)
        return self.uri.format(namespace=namespace)
44
45
class SegmentEnc(Segment):
    """Segment whose `content` is replaced by its AES-CBC encryption.

    The unencrypted bytes stay available as `content_plain`.
    """

    def __init__(self, num, key, iv, *args, **kwargs):
        super().__init__(num, *args, **kwargs)
        plaintext = self.content
        self.content_plain = plaintext
        self.content = encrypt(plaintext, key, iv)
51
52
class TestHLSStreamRepr(unittest.TestCase):
    def test_repr(self):
        """repr() of an HLSStream shows the playlist URL and the master URL (None when absent)."""
        session = Streamlink()
        playlist_url = "https://foo.bar/playlist.m3u8"
        master_url = "https://foo.bar/master.m3u8"

        without_master = hls.HLSStream(session, playlist_url)
        self.assertEqual(repr(without_master), "<HLSStream('https://foo.bar/playlist.m3u8', None)>")

        with_master = hls.HLSStream(session, playlist_url, master_url)
        self.assertEqual(repr(with_master), "<HLSStream('https://foo.bar/playlist.m3u8', 'https://foo.bar/master.m3u8')>")
62
63
class TestHLSVariantPlaylist(unittest.TestCase):
    @classmethod
    def get_master_playlist(cls, playlist):
        """Return the text of the test resource file *playlist*."""
        with text(playlist) as resource:
            content = resource.read()
        return content

    def subject(self, playlist, options=None):
        """Serve the *playlist* resource from a mocked URL and parse it as an HLS master playlist."""
        master_url = "http://mocked/{0}/master.m3u8".format(self.id())

        with requests_mock.Mocker() as mock:
            mock.get(master_url, text=self.get_master_playlist(playlist))
            # The session is created and used while the HTTP mock is still active
            session = Streamlink(options)
            return hls.HLSStream.parse_variant_playlist(session, master_url)

    def test_variant_playlist(self):
        streams = self.subject("hls/test_master.m3u8")
        expected_names = ["720p", "720p_alt", "480p", "360p", "160p", "1080p (source)", "90k"]

        self.assertEqual(list(streams.keys()), expected_names, "Finds all streams in master playlist")
        self.assertTrue(
            all(isinstance(stream, hls.HLSStream) for stream in streams.values()),
            "Returns HLSStream instances"
        )
91
92
@patch("streamlink.stream.hls.HLSStreamWorker.wait", Mock(return_value=True))
class TestHLSStream(TestMixinStreamHLS, unittest.TestCase):
    def get_session(self, options=None, *args, **kwargs):
        # Every test in this class runs with a live edge of 3 segments
        session = super().get_session(options)
        session.set_option("hls-live-edge", 3)

        return session

    def test_offset_and_duration(self):
        thread, segments = self.subject([
            Playlist(1234, [Segment(0), Segment(1, duration=0.5), Segment(2, duration=0.5), Segment(3)], end=True)
        ], streamoptions={"start_offset": 1, "duration": 1})

        def selected(segment):
            # Segments 1 and 2 are the ones the offset/duration window is expected to cover
            return 0 < segment.num < 3

        data = self.await_read(read_all=True)
        self.assertEqual(data, self.content(segments, cond=selected), "Respects the offset and duration")
        self.assertTrue(all(self.called(s) for s in segments.values() if selected(s)), "Downloads second and third segment")
        # Use the negation of `selected` so segments 0 and 3 are really checked
        # (a chained comparison like `0 > n > 3` is always false and would test nothing)
        self.assertFalse(any(self.called(s) for s in segments.values() if not selected(s)), "Skips other segments")
110
111
@patch("streamlink.stream.hls.HLSStreamWorker.wait", Mock(return_value=True))
class TestHLSStreamEncrypted(TestMixinStreamHLS, unittest.TestCase):
    def get_session(self, options=None, *args, **kwargs):
        # Every test in this class runs with a live edge of 3 segments
        session = super().get_session(options)
        session.set_option("hls-live-edge", 3)

        return session

    def gen_key(self, aes_key=None, aes_iv=None, method="AES-128", uri=None, keyformat="identity", keyformatversions=1):
        """Build an EXT-X-KEY tag and register a mocked GET that serves *aes_key* at the tag's URL.

        A random 16-byte key / IV is generated when aes_key / aes_iv is falsy.

        :return: tuple of (aes_key, aes_iv, key_tag)
        """
        aes_key = aes_key or os.urandom(16)
        aes_iv = aes_iv or os.urandom(16)

        key = TagKey(method=method, uri=uri, iv=aes_iv, keyformat=keyformat, keyformatversions=keyformatversions)
        self.mock("GET", key.url(self.id()), content=aes_key)

        return aes_key, aes_iv, key

    def test_hls_encrypted_aes128(self):
        aesKey, aesIv, key = self.gen_key()

        # noinspection PyTypeChecker
        thread, segments = self.subject([
            Playlist(0, [key] + [SegmentEnc(num, aesKey, aesIv) for num in range(0, 4)]),
            Playlist(4, [key] + [SegmentEnc(num, aesKey, aesIv) for num in range(4, 8)], end=True)
        ])

        data = self.await_read(read_all=True)
        expected = self.content(segments, prop="content_plain", cond=lambda s: s.num >= 1)
        self.assertEqual(data, expected, "Decrypts the AES-128 identity stream")
        self.assertTrue(self.called(key), "Downloads encryption key")
        self.assertFalse(any(self.called(s) for s in segments.values() if s.num < 1), "Skips first segment")
        self.assertTrue(all(self.called(s) for s in segments.values() if s.num >= 1), "Downloads all remaining segments")

    def test_hls_encrypted_aes128_key_uri_override(self):
        aesKey, aesIv, key = self.gen_key(uri="http://real-mocked/{namespace}/encryption.key?foo=bar")
        # Flip every bit, so the key served at the playlist's own URI cannot decrypt the segments
        aesKeyInvalid = bytes(byte ^ 0xFF for byte in aesKey)
        _, __, key_invalid = self.gen_key(aesKeyInvalid, aesIv, uri="http://mocked/{namespace}/encryption.key?foo=bar")

        # noinspection PyTypeChecker
        thread, segments = self.subject([
            Playlist(0, [key_invalid] + [SegmentEnc(num, aesKey, aesIv) for num in range(0, 4)]),
            Playlist(4, [key_invalid] + [SegmentEnc(num, aesKey, aesIv) for num in range(4, 8)], end=True)
        ], options={"hls-segment-key-uri": "{scheme}://real-{netloc}{path}?{query}"})

        data = self.await_read(read_all=True)
        expected = self.content(segments, prop="content_plain", cond=lambda s: s.num >= 1)
        self.assertEqual(data, expected, "Decrypts stream from custom key")
        self.assertFalse(self.called(key_invalid), "Skips encryption key")
        self.assertTrue(self.called(key), "Downloads custom encryption key")
161
162
@patch("streamlink.stream.hls.HLSStreamWorker.wait", Mock(return_value=True))
@patch("streamlink.stream.hls.HLSStreamWriter.run", Mock(return_value=True))
class TestHlsPlaylistReloadTime(TestMixinStreamHLS, unittest.TestCase):
    # Shared playlist segments; the third positional Segment argument is presumably the duration (11, 7, 5, 3)
    segments = [Segment(0, "", 11), Segment(1, "", 7), Segment(2, "", 5), Segment(3, "", 3)]

    def get_session(self, options=None, reload_time=None, *args, **kwargs):
        # The fixed live edge and the requested reload-time mode override any caller-supplied options
        merged = dict(options or {})
        merged.update({
            "hls-live-edge": 3,
            "hls-playlist-reload-time": reload_time
        })
        return super().get_session(merged)

    def subject(self, *args, **kwargs):
        """Run the stream until all data is read and return the worker's playlist reload time."""
        stream_thread, _segments = super().subject(*args, **kwargs)
        self.await_read(read_all=True)

        worker = stream_thread.reader.worker
        return worker.playlist_reload_time

    def _subject_reload_time(self, reload_time, targetduration=6, segments=None):
        # Build a single ended playlist and return the resulting reload time
        playlist_segments = self.segments if segments is None else segments
        playlist = Playlist(0, playlist_segments, end=True, targetduration=targetduration)
        return self.subject([playlist], reload_time=reload_time)

    def test_hls_playlist_reload_time_default(self):
        self.assertEqual(self._subject_reload_time("default"), 6, "default sets the reload time to the playlist's target duration")

    def test_hls_playlist_reload_time_segment(self):
        self.assertEqual(self._subject_reload_time("segment"), 3, "segment sets the reload time to the playlist's last segment")

    def test_hls_playlist_reload_time_live_edge(self):
        self.assertEqual(
            self._subject_reload_time("live-edge"),
            8,
            "live-edge sets the reload time to the sum of the number of segments of the live-edge"
        )

    def test_hls_playlist_reload_time_number(self):
        self.assertEqual(self._subject_reload_time("4"), 4, "number values override the reload time")

    def test_hls_playlist_reload_time_number_invalid(self):
        self.assertEqual(
            self._subject_reload_time("0"),
            6,
            "invalid number values set the reload time to the playlist's targetduration"
        )

    def test_hls_playlist_reload_time_no_target_duration(self):
        self.assertEqual(
            self._subject_reload_time("default", targetduration=0),
            8,
            "uses the live-edge sum if the playlist is missing the targetduration data"
        )

    def test_hls_playlist_reload_time_no_data(self):
        self.assertEqual(
            self._subject_reload_time("default", targetduration=0, segments=[]),
            15,
            "sets reload time to 15 seconds when no data is available"
        )
207
208
@patch('streamlink.stream.hls.FFMPEGMuxer.is_usable', Mock(return_value=True))
class TestHlsExtAudio(unittest.TestCase):
    """Selection of external audio renditions from a master playlist via the hls-audio-select option.

    FFMPEGMuxer.is_usable is patched to return True for every test method.
    """

    master_url = "http://mocked/path/master.m3u8"

    @property
    def playlist(self):
        """Text of the hls/test_2.m3u8 master playlist resource."""
        with text("hls/test_2.m3u8") as pl:
            return pl.read()

    def run_streamlink(self, playlist, audio_select=None):
        """Parse the master playlist at URL *playlist* and return the resulting dict of streams.

        :param playlist: URL of the master playlist (not its content)
        :param audio_select: value for the hls-audio-select option; the option is left unset when falsy
        """
        streamlink = Streamlink()

        if audio_select:
            streamlink.set_option("hls-audio-select", audio_select)

        master_stream = hls.HLSStream.parse_variant_playlist(streamlink, playlist)

        return master_stream

    def _parse_mocked_master(self, audio_select=None):
        # Serve the fixture at master_url only while the master playlist is being parsed
        with requests_mock.Mocker() as mock:
            mock.get(self.master_url, text=self.playlist)
            return self.run_streamlink(self.master_url, audio_select)

    def _substream_urls(self, audio_select):
        # URLs of the "video" stream's substreams: the video playlist first, then the selected audio playlists
        return [substream.url for substream in self._parse_mocked_master(audio_select)['video'].substreams]

    def test_hls_ext_audio_not_selected(self):
        """Without hls-audio-select, the "video" stream has no substreams and uses the video playlist URL."""
        master_stream = self._parse_mocked_master()['video']

        with pytest.raises(AttributeError):
            master_stream.substreams

        assert master_stream.url == 'http://mocked/path/playlist.m3u8'

    def test_hls_ext_audio_en(self):
        """Selecting "en" adds the en.m3u8 audio rendition as a substream next to the video playlist."""
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/en.m3u8']

        self.assertEqual(self._substream_urls('en'), expected)

    def test_hls_ext_audio_es(self):
        """Selecting "es" adds the es.m3u8 audio rendition as a substream next to the video playlist."""
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/es.m3u8']

        self.assertEqual(self._substream_urls('es'), expected)

    def test_hls_ext_audio_all(self):
        """Selecting "en,es" adds both audio renditions as substreams, in playlist order."""
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/en.m3u8', 'http://mocked/path/es.m3u8']

        self.assertEqual(self._substream_urls('en,es'), expected)

    def test_hls_ext_audio_wildcard(self):
        """Selecting "*" adds every audio rendition as a substream."""
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/en.m3u8', 'http://mocked/path/es.m3u8']

        self.assertEqual(self._substream_urls('*'), expected)
311