"""Tests for the HLS stream implementation in ``streamlink.stream.hls``."""
import os
import unittest
from unittest.mock import Mock, patch

import pytest
import requests_mock
from Crypto.Cipher import AES

from streamlink.session import Streamlink
from streamlink.stream import hls
from tests.mixins.stream_hls import Playlist, Segment, Tag, TestMixinStreamHLS
from tests.resources import text


def pkcs7_encode(data, keySize):
    """Return *data* followed by PKCS#7 padding up to a multiple of *keySize* bytes.

    Despite its historical name, *keySize* is the cipher's *block* size: PKCS#7 pads to the
    block size, not to the key length. When *data* is already aligned, a full block of
    padding is appended, as PKCS#7 requires.
    """
    pad_len = keySize - (len(data) % keySize)
    # Every padding byte carries the padding length as its value.
    return b"".join((data, bytes([pad_len]) * pad_len))


def encrypt(data, key, iv):
    """Encrypt *data* with AES in CBC mode using *key* and *iv*, after PKCS#7 padding.

    Padding uses ``AES.block_size`` (16 bytes) rather than ``len(key)``. With a 24- or 32-byte
    key, padding to the key length would produce padding values larger than the block size,
    which a strict PKCS#7 unpadder (e.g. pycryptodome's ``Crypto.Util.Padding.unpad``) rejects.
    """
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return cipher.encrypt(pkcs7_encode(data, AES.block_size))


class TagKey(Tag):
    """An ``EXT-X-KEY`` playlist tag pointing at an encryption key URL.

    :param method: value of the METHOD attribute
    :param uri: optional URL template containing ``{namespace}``; ``None`` falls back to the
                default Tag URL, ``False`` omits the URI attribute entirely
    :param iv: optional IV bytes, rendered as the hex IV attribute
    :param keyformat: optional KEYFORMAT attribute value
    :param keyformatversions: optional KEYFORMATVERSIONS attribute value
    """

    # presumably used by the Tag base class to build the default key URL — confirm in the mixin
    path = "encryption.key"

    def __init__(self, method="NONE", uri=None, iv=None, keyformat=None, keyformatversions=None):
        attrs = {"METHOD": method}
        if uri is not False:  # pragma: no branch
            # Rendered lazily: the final URL depends on the namespace passed in at render time.
            attrs.update({"URI": lambda tag, namespace: tag.val_quoted_string(tag.url(namespace))})
        if iv is not None:  # pragma: no branch
            attrs.update({"IV": self.val_hex(iv)})
        if keyformat is not None:  # pragma: no branch
            attrs.update({"KEYFORMAT": self.val_quoted_string(keyformat)})
        if keyformatversions is not None:  # pragma: no branch
            attrs.update({"KEYFORMATVERSIONS": self.val_quoted_string(keyformatversions)})
        super().__init__("EXT-X-KEY", attrs)
        self.uri = uri

    def url(self, namespace):
        """Return the key URL: the custom *uri* template with ``{namespace}`` filled in, or the base Tag URL."""
        return self.uri.format(namespace=namespace) if self.uri else super().url(namespace)


class SegmentEnc(Segment):
    """A Segment whose ``content`` is AES-CBC encrypted with *key*/*iv*.

    The original plaintext is kept in ``content_plain`` so tests can compare decrypted output.
    """

    def __init__(self, num, key, iv, *args, **kwargs):
        super().__init__(num, *args, **kwargs)
        self.content_plain = self.content
        self.content = encrypt(self.content, key, iv)


class TestHLSStreamRepr(unittest.TestCase):
    def test_repr(self):
        session = Streamlink()

        stream = hls.HLSStream(session, "https://foo.bar/playlist.m3u8")
        self.assertEqual(repr(stream), "<HLSStream('https://foo.bar/playlist.m3u8', None)>")

        stream = hls.HLSStream(session, "https://foo.bar/playlist.m3u8", "https://foo.bar/master.m3u8")
        self.assertEqual(repr(stream), "<HLSStream('https://foo.bar/playlist.m3u8', 'https://foo.bar/master.m3u8')>")


class TestHLSVariantPlaylist(unittest.TestCase):
    @classmethod
    def get_master_playlist(cls, playlist):
        """Return the text of the test resource file *playlist*."""
        with text(playlist) as pl:
            return pl.read()

    def subject(self, playlist, options=None):
        """Serve resource *playlist* at a mocked URL and parse it as a variant playlist."""
        with requests_mock.Mocker() as mock:
            url = "http://mocked/{0}/master.m3u8".format(self.id())
            content = self.get_master_playlist(playlist)
            mock.get(url, text=content)

            session = Streamlink(options)

            return hls.HLSStream.parse_variant_playlist(session, url)

    def test_variant_playlist(self):
        streams = self.subject("hls/test_master.m3u8")
        self.assertEqual(
            list(streams.keys()),
            ["720p", "720p_alt", "480p", "360p", "160p", "1080p (source)", "90k"],
            "Finds all streams in master playlist"
        )
        self.assertTrue(
            all(isinstance(stream, hls.HLSStream) for stream in streams.values()),
            "Returns HLSStream instances"
        )


@patch("streamlink.stream.hls.HLSStreamWorker.wait", Mock(return_value=True))
class TestHLSStream(TestMixinStreamHLS, unittest.TestCase):
    def get_session(self, options=None, *args, **kwargs):
        session = super().get_session(options)
        session.set_option("hls-live-edge", 3)

        return session

    def test_offset_and_duration(self):
        thread, segments = self.subject([
            Playlist(1234, [Segment(0), Segment(1, duration=0.5), Segment(2, duration=0.5), Segment(3)], end=True)
        ], streamoptions={"start_offset": 1, "duration": 1})

        data = self.await_read(read_all=True)
        self.assertEqual(data, self.content(segments, cond=lambda s: 0 < s.num < 3), "Respects the offset and duration")
        self.assertTrue(all(self.called(s) for s in segments.values() if 0 < s.num < 3), "Downloads second and third segment")
        # The previous filter `0 > s.num > 3` could never match, so this check was vacuous.
        self.assertFalse(any(self.called(s) for s in segments.values() if not 0 < s.num < 3), "Skips other segments")


@patch("streamlink.stream.hls.HLSStreamWorker.wait", Mock(return_value=True))
class TestHLSStreamEncrypted(TestMixinStreamHLS, unittest.TestCase):
    def get_session(self, options=None, *args, **kwargs):
        session = super().get_session(options)
        session.set_option("hls-live-edge", 3)

        return session

    def gen_key(self, aes_key=None, aes_iv=None, method="AES-128", uri=None, keyformat="identity", keyformatversions=1):
        """Create a TagKey and mock a GET of its URL that serves the raw key bytes.

        Random 16-byte key/IV values are generated when *aes_key*/*aes_iv* are not given.

        :return: tuple of (key bytes, IV bytes, TagKey)
        """
        aes_key = aes_key or os.urandom(16)
        aes_iv = aes_iv or os.urandom(16)

        key = TagKey(method=method, uri=uri, iv=aes_iv, keyformat=keyformat, keyformatversions=keyformatversions)
        self.mock("GET", key.url(self.id()), content=aes_key)

        return aes_key, aes_iv, key

    def test_hls_encrypted_aes128(self):
        aesKey, aesIv, key = self.gen_key()

        # noinspection PyTypeChecker
        thread, segments = self.subject([
            Playlist(0, [key] + [SegmentEnc(num, aesKey, aesIv) for num in range(0, 4)]),
            Playlist(4, [key] + [SegmentEnc(num, aesKey, aesIv) for num in range(4, 8)], end=True)
        ])

        data = self.await_read(read_all=True)
        expected = self.content(segments, prop="content_plain", cond=lambda s: s.num >= 1)
        self.assertEqual(data, expected, "Decrypts the AES-128 identity stream")
        self.assertTrue(self.called(key), "Downloads encryption key")
        self.assertFalse(any(self.called(s) for s in segments.values() if s.num < 1), "Skips first segment")
        self.assertTrue(all(self.called(s) for s in segments.values() if s.num >= 1), "Downloads all remaining segments")

    def test_hls_encrypted_aes128_key_uri_override(self):
        # The valid key is served on the "real-" host; the playlist only advertises a URL
        # serving a bit-inverted (wrong) key, which hls-segment-key-uri must rewrite.
        aesKey, aesIv, key = self.gen_key(uri="http://real-mocked/{namespace}/encryption.key?foo=bar")
        aesKeyInvalid = bytes(b ^ 0xFF for b in aesKey)
        _, __, key_invalid = self.gen_key(aesKeyInvalid, aesIv, uri="http://mocked/{namespace}/encryption.key?foo=bar")

        # noinspection PyTypeChecker
        thread, segments = self.subject([
            Playlist(0, [key_invalid] + [SegmentEnc(num, aesKey, aesIv) for num in range(0, 4)]),
            Playlist(4, [key_invalid] + [SegmentEnc(num, aesKey, aesIv) for num in range(4, 8)], end=True)
        ], options={"hls-segment-key-uri": "{scheme}://real-{netloc}{path}?{query}"})

        data = self.await_read(read_all=True)
        expected = self.content(segments, prop="content_plain", cond=lambda s: s.num >= 1)
        self.assertEqual(data, expected, "Decrypts stream from custom key")
        self.assertFalse(self.called(key_invalid), "Skips encryption key")
        self.assertTrue(self.called(key), "Downloads custom encryption key")


@patch("streamlink.stream.hls.HLSStreamWorker.wait", Mock(return_value=True))
@patch("streamlink.stream.hls.HLSStreamWriter.run", Mock(return_value=True))
class TestHlsPlaylistReloadTime(TestMixinStreamHLS, unittest.TestCase):
    """Checks the worker's playlist reload time for each hls-playlist-reload-time value.

    The writer's ``run`` is stubbed out, so these tests only inspect the worker's state.
    """

    # assumes Segment(num, <second arg>, duration) — durations 11, 7, 5, 3; confirm in the mixin
    segments = [Segment(0, "", 11), Segment(1, "", 7), Segment(2, "", 5), Segment(3, "", 3)]

    def get_session(self, options=None, reload_time=None, *args, **kwargs):
        return super().get_session(dict(options or {}, **{
            "hls-live-edge": 3,
            "hls-playlist-reload-time": reload_time
        }))

    def subject(self, *args, **kwargs):
        """Run the stream to completion and return the worker's computed playlist_reload_time."""
        thread, _ = super().subject(*args, **kwargs)
        self.await_read(read_all=True)

        return thread.reader.worker.playlist_reload_time

    def test_hls_playlist_reload_time_default(self):
        time = self.subject([Playlist(0, self.segments, end=True, targetduration=6)], reload_time="default")
        self.assertEqual(time, 6, "default sets the reload time to the playlist's target duration")

    def test_hls_playlist_reload_time_segment(self):
        time = self.subject([Playlist(0, self.segments, end=True, targetduration=6)], reload_time="segment")
        self.assertEqual(time, 3, "segment sets the reload time to the playlist's last segment")

    def test_hls_playlist_reload_time_live_edge(self):
        time = self.subject([Playlist(0, self.segments, end=True, targetduration=6)], reload_time="live-edge")
        self.assertEqual(time, 8, "live-edge sets the reload time to the sum of the number of segments of the live-edge")

    def test_hls_playlist_reload_time_number(self):
        time = self.subject([Playlist(0, self.segments, end=True, targetduration=6)], reload_time="4")
        self.assertEqual(time, 4, "number values override the reload time")

    def test_hls_playlist_reload_time_number_invalid(self):
        time = self.subject([Playlist(0, self.segments, end=True, targetduration=6)], reload_time="0")
        self.assertEqual(time, 6, "invalid number values set the reload time to the playlist's targetduration")

    def test_hls_playlist_reload_time_no_target_duration(self):
        time = self.subject([Playlist(0, self.segments, end=True, targetduration=0)], reload_time="default")
        self.assertEqual(time, 8, "uses the live-edge sum if the playlist is missing the targetduration data")

    def test_hls_playlist_reload_time_no_data(self):
        time = self.subject([Playlist(0, [], end=True, targetduration=0)], reload_time="default")
        self.assertEqual(time, 15, "sets reload time to 15 seconds when no data is available")


@patch('streamlink.stream.hls.FFMPEGMuxer.is_usable', Mock(return_value=True))
class TestHlsExtAudio(unittest.TestCase):
    """Checks which external audio playlists are selected via the hls-audio-select option."""

    @property
    def playlist(self):
        """Text of the ``hls/test_2.m3u8`` master playlist resource."""
        with text("hls/test_2.m3u8") as pl:
            return pl.read()

    def run_streamlink(self, playlist, audio_select=None):
        """Parse the master playlist at URL *playlist*, optionally setting hls-audio-select.

        :return: the dict of streams returned by ``HLSStream.parse_variant_playlist``
        """
        streamlink = Streamlink()

        if audio_select:
            streamlink.set_option("hls-audio-select", audio_select)

        master_stream = hls.HLSStream.parse_variant_playlist(streamlink, playlist)

        return master_stream

    def test_hls_ext_audio_not_selected(self):
        """Without hls-audio-select, the video stream has no substreams and points at the video playlist."""
        master_url = "http://mocked/path/master.m3u8"

        with requests_mock.Mocker() as mock:
            mock.get(master_url, text=self.playlist)
            master_stream = self.run_streamlink(master_url)['video']

        with pytest.raises(AttributeError):
            master_stream.substreams

        assert master_stream.url == 'http://mocked/path/playlist.m3u8'

    def test_hls_ext_audio_en(self):
        """Selecting "en" yields substreams for the video playlist and the English audio playlist only."""
        master_url = "http://mocked/path/master.m3u8"
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/en.m3u8']

        with requests_mock.Mocker() as mock:
            mock.get(master_url, text=self.playlist)
            master_stream = self.run_streamlink(master_url, 'en')

            substreams = master_stream['video'].substreams
            result = [x.url for x in substreams]

            # Check result
            self.assertEqual(result, expected)

    def test_hls_ext_audio_es(self):
        """Selecting "es" yields substreams for the video playlist and the Spanish audio playlist only."""
        master_url = "http://mocked/path/master.m3u8"
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/es.m3u8']

        with requests_mock.Mocker() as mock:
            mock.get(master_url, text=self.playlist)
            master_stream = self.run_streamlink(master_url, 'es')

            substreams = master_stream['video'].substreams

            result = [x.url for x in substreams]

            # Check result
            self.assertEqual(result, expected)

    def test_hls_ext_audio_all(self):
        """Selecting "en,es" yields substreams for the video playlist and both audio playlists."""
        master_url = "http://mocked/path/master.m3u8"
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/en.m3u8', 'http://mocked/path/es.m3u8']

        with requests_mock.Mocker() as mock:
            mock.get(master_url, text=self.playlist)
            master_stream = self.run_streamlink(master_url, 'en,es')

            substreams = master_stream['video'].substreams

            result = [x.url for x in substreams]

            # Check result
            self.assertEqual(result, expected)

    def test_hls_ext_audio_wildcard(self):
        """Selecting "*" yields substreams for the video playlist and every audio playlist."""
        master_url = "http://mocked/path/master.m3u8"
        expected = ['http://mocked/path/playlist.m3u8', 'http://mocked/path/en.m3u8', 'http://mocked/path/es.m3u8']

        with requests_mock.Mocker() as mock:
            mock.get(master_url, text=self.playlist)
            master_stream = self.run_streamlink(master_url, '*')

            substreams = master_stream['video'].substreams

            result = [x.url for x in substreams]

            # Check result
            self.assertEqual(result, expected)