1import copy
2import datetime
3import logging
4import math
5import re
6import time
7from collections import defaultdict, namedtuple
8from contextlib import contextmanager
9from itertools import count, repeat
10from urllib.parse import urljoin, urlparse, urlsplit, urlunparse, urlunsplit
11
12from isodate import Duration, parse_datetime, parse_duration
13
try:
    # datetime.timezone (Python 3.2+) already provides a UTC tzinfo instance
    utc = datetime.timezone.utc
except AttributeError:
    class UTC(datetime.tzinfo):
        """Minimal fixed zero-offset tzinfo for interpreters without datetime.timezone."""

        def utcoffset(self, dt):
            return datetime.timedelta(0)

        def tzname(self, dt):
            return "UTC"

        def dst(self, dt):
            return datetime.timedelta(0)

    utc = UTC()

log = logging.getLogger(__name__)
# The Unix epoch as an aware datetime; also the default Segment.available_at,
# i.e. "available since forever".
epoch_start = datetime.datetime(1970, 1, 1, tzinfo=utc)
31
32
class Segment:
    """A single segment to fetch: its URL plus timing and byte-range metadata."""

    def __init__(self, url, duration, init=False, content=True, available_at=epoch_start, range=None):
        # where to fetch the segment and how long it plays (seconds, or None if unknown)
        self.url, self.duration = url, duration
        # init: initialization segment; content: carries media data
        self.init, self.content = init, content
        # earliest wall-clock time the segment can be fetched, and an optional
        # (start, length) byte range as produced by MPDParsers.range
        self.available_at, self.range = available_at, range
41
42
def datetime_to_seconds(dt):
    """Return the (float) number of seconds from the Unix epoch to the aware datetime ``dt``."""
    since_epoch = dt - epoch_start
    return since_epoch.total_seconds()
45
46
def count_dt(firstval=None, step=datetime.timedelta(seconds=1)):
    """Yield ``firstval``, ``firstval + step``, ``firstval + 2*step``, ... forever.

    This is ``itertools.count`` for datetimes.

    :param firstval: first value to yield. When omitted (None), the current UTC time is
        read when iteration starts. The previous default was evaluated once at import
        time, so every default call started at the module load time.
    :param step: increment added between consecutive values.
    """
    current = datetime.datetime.now(tz=utc) if firstval is None else firstval
    while True:
        yield current
        current += step
52
53
@contextmanager
def freeze_timeline(mpd):
    """Snapshot ``mpd.timelines`` and restore it when the ``with`` block exits.

    Any timeline progress recorded inside the block is discarded. The restore runs
    in a ``finally`` so it also happens when the block raises; previously an
    exception left the mutated timelines in place.

    :param mpd: object with a ``timelines`` mapping (an MPD instance).
    """
    # shallow copy: values are plain ints, so this fully captures the state
    # (copy.copy of a defaultdict keeps its default_factory)
    timelines = copy.copy(mpd.timelines)
    try:
        yield
    finally:
        mpd.timelines = timelines
59
60
@contextmanager
def sleeper(duration):
    """Make the ``with`` block take at least ``duration`` seconds of wall-clock time.

    After the block finishes, sleep for whatever part of ``duration`` it did not use.
    If the block already took longer, return immediately.
    """
    started = time.time()
    yield
    remaining = duration - (time.time() - started)
    if remaining <= 0:
        return
    time.sleep(remaining)
68
69
def sleep_until(walltime):
    """Block until the aware datetime ``walltime``; return at once if it is already past."""
    remaining = (walltime - datetime.datetime.now(tz=utc)).total_seconds()
    if remaining <= 0:
        return
    time.sleep(remaining)
75
76
class MPDParsers:
    """Attribute parsers, passed as ``parser=`` to ``MPDNode.attr``."""

    @staticmethod
    def bool_str(v):
        """Return True only for the string "true" (case-insensitive)."""
        return v.lower() == "true"

    @staticmethod
    def type(type_):
        """Validate MPD@type, which must be "static" or "dynamic".

        :raises MPDParsingError: for any other value.
        """
        if type_ not in ("static", "dynamic"):
            raise MPDParsingError("@type must be static or dynamic")
        return type_

    @staticmethod
    def duration(duration):
        """Parse an ISO 8601 duration string with ``isodate.parse_duration``."""
        return parse_duration(duration)

    @staticmethod
    def datetime(dt):
        """Parse an ISO 8601 date-time into a timezone-aware datetime in UTC.

        Values without an offset are taken to be UTC. Values with an explicit offset
        (e.g. ``+02:00``) are converted to UTC. Previously the offset was overwritten,
        which shifted the instant.
        """
        parsed = parse_datetime(dt)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=utc)
        return parsed.astimezone(utc)

    @staticmethod
    def segment_template(url_template):
        """Turn a DASH URL template into a ``str.format`` callable.

        Substitutions:

        * ``$Name$`` becomes ``{Name}``.
        * ``$Name%fmt$`` becomes ``{Name:fmt}``.
        * ``$$`` becomes a literal ``$``.

        Literal ``{`` and ``}`` in the template are escaped so that they survive
        ``str.format`` unchanged.

        >>> MPDParsers.segment_template("seg-$Number%05d$.m4s")(Number=3)
        'seg-00003.m4s'
        """
        def escape(text):
            return text.replace("{", "{{").replace("}", "}}")

        end = 0
        parts = []
        for m in re.finditer(r"(.*?)(?:\$\$|\$(\w+)(?:%([\w.]+))?\$)", url_template):
            end = m.end()
            parts.append(escape(m.group(1)))
            if m.group(2) is None:
                # "$$" is the escape sequence for a literal "$"
                parts.append("$")
            else:
                parts.append("{{{0}{1}}}".format(m.group(2), (":" + m.group(3)) if m.group(3) else ""))
        parts.append(escape(url_template[end:]))
        return "".join(parts).format

    @staticmethod
    def frame_rate(frame_rate):
        """Parse a frame rate given either as a number ("25") or a fraction ("30000/1001")."""
        if "/" in frame_rate:
            a, b = frame_rate.split("/")
            return float(a) / float(b)
        else:
            return float(frame_rate)

    @staticmethod
    def timedelta(timescale=1):
        """Return a parser turning a value in ``timescale`` units into a timedelta.

        The result is truncated to whole seconds.
        """
        def _timedelta(seconds):
            return datetime.timedelta(seconds=int(float(seconds) / float(timescale)))

        return _timedelta

    @staticmethod
    def range(range_spec):
        """Parse a byte range "first-last" (last optional) into ``(start, length)``.

        ``length`` is None when the last byte is omitted ("100-").
        A range ending at byte 0 ("0-0") gives length 1; previously a zero end byte
        was mistaken for "missing".

        :raises MPDParsingError: when the spec does not contain exactly one "-".
        """
        parts = range_spec.split("-")
        if len(parts) != 2:
            raise MPDParsingError("invalid byte-range-spec")

        start = int(parts[0])
        end = int(parts[1]) if parts[1] else None
        length = (end - start) + 1 if end is not None else None
        return start, length
130
131
class MPDParsingError(Exception):
    """Raised when an MPD document is missing required content or has invalid values."""
134
135
136class MPDNode:
137    __tag__ = None
138
139    def __init__(self, node, root=None, parent=None, *args, **kwargs):
140        self.node = node
141        self.root = root
142        self.parent = parent
143        self._base_url = kwargs.get("base_url")
144        self.attributes = set()
145        if self.__tag__ and self.node.tag.lower() != self.__tag__.lower():
146            raise MPDParsingError("root tag did not match the expected tag: {}".format(self.__tag__))
147
148    @property
149    def attrib(self):
150        return self.node.attrib
151
152    @property
153    def text(self):
154        return self.node.text
155
156    def __str__(self):
157        return "<{tag} {attrs}>".format(
158            tag=self.__tag__,
159            attrs=" ".join("@{}={}".format(attr, getattr(self, attr)) for attr in self.attributes)
160        )
161
162    def attr(self, key, default=None, parser=None, required=False, inherited=False):
163        self.attributes.add(key)
164        if key in self.attrib:
165            value = self.attrib.get(key)
166            if parser and callable(parser):
167                return parser(value)
168            else:
169                return value
170        elif inherited:
171            if self.parent and hasattr(self.parent, key) and getattr(self.parent, key):
172                return getattr(self.parent, key)
173
174        if required:
175            raise MPDParsingError("could not find required attribute {tag}@{attr} ".format(attr=key, tag=self.__tag__))
176        else:
177            return default
178
179    def children(self, cls, minimum=0, maximum=None):
180
181        children = self.node.findall(cls.__tag__)
182        if len(children) < minimum or (maximum and len(children) > maximum):
183            raise MPDParsingError("expected to find {}/{} required [{}..{})".format(
184                self.__tag__, cls.__tag__, minimum, maximum or "unbound"))
185
186        return list(map(lambda x: cls(x[1], root=self.root, parent=self, i=x[0], base_url=self.base_url),
187                        enumerate(children)))
188
189    def only_child(self, cls, minimum=0):
190        children = self.children(cls, minimum=minimum, maximum=1)
191        return children[0] if len(children) else None
192
193    def walk_back(self, cls=None, f=lambda x: x):
194        node = self.parent
195        while node:
196            if cls is None or cls.__tag__ == node.__tag__:
197                yield f(node)
198            node = node.parent
199
200    def walk_back_get_attr(self, attr):
201        parent_attrs = [getattr(n, attr) for n in self.walk_back() if hasattr(n, attr)]
202        return parent_attrs[0] if len(parent_attrs) else None
203
204    @property
205    def base_url(self):
206        base_url = self._base_url
207        if hasattr(self, "baseURLs") and len(self.baseURLs):
208            base_url = BaseURL.join(base_url, self.baseURLs[0].url)
209        return base_url
210
211
class MPD(MPDNode):
    """
    Represents the MPD as a whole

    Should validate the XML input and provide methods to get segment URLs for each Period, AdaptationSet and
    Representation.

    """
    __tag__ = "MPD"

    def __init__(self, node, root=None, parent=None, url=None, *args, **kwargs):
        # top level has no parent
        super().__init__(node, root=self, *args, **kwargs)
        # parser attributes
        self.url = url
        # last segment @t yielded per SegmentTemplate parent id (see SegmentTemplate.format_media);
        # -1 means nothing has been yielded yet
        self.timelines = defaultdict(lambda: -1)
        self.timelines.update(kwargs.pop("timelines", {}))
        self.id = self.attr("id")
        self.profiles = self.attr("profiles", required=True)
        self.type = self.attr("type", default="static", parser=MPDParsers.type)
        self.minimumUpdatePeriod = self.attr("minimumUpdatePeriod", parser=MPDParsers.duration, default=Duration())
        self.minBufferTime = self.attr("minBufferTime", parser=MPDParsers.duration, required=True)
        self.timeShiftBufferDepth = self.attr("timeShiftBufferDepth", parser=MPDParsers.duration)
        self.availabilityStartTime = self.attr("availabilityStartTime", parser=MPDParsers.datetime,
                                               default=datetime.datetime.fromtimestamp(0, utc),  # earliest date
                                               required=self.type == "dynamic")
        self.publishTime = self.attr("publishTime", parser=MPDParsers.datetime, required=self.type == "dynamic")
        self.availabilityEndTime = self.attr("availabilityEndTime", parser=MPDParsers.datetime)
        self.mediaPresentationDuration = self.attr("mediaPresentationDuration", parser=MPDParsers.duration)
        self.suggestedPresentationDelay = self.attr("suggestedPresentationDelay", parser=MPDParsers.duration)

        # parse children
        location = self.children(Location)
        self.location = location[0] if location else None
        # an empty <Location/> (text None) is ignored instead of crashing urlparse
        location_url = (self.location.text or "").strip() if self.location else ""
        if location_url:
            # a relative Location is resolved against the URL the manifest came from
            self.url = urljoin(url, location_url) if url else location_url
            urlp = list(urlparse(self.url))
            # drop the last path component (the manifest file name) to get the base URL;
            # rpartition also copes with paths that contain no "/"
            urlp[2] = urlp[2].rpartition("/")[0]
            self._base_url = urlunparse(urlp)

        self.baseURLs = self.children(BaseURL)
        self.periods = self.children(Period, minimum=1)
        self.programInformation = self.children(ProgramInformation)
256
257
class ProgramInformation(MPDNode):
    """<ProgramInformation> element; kept as a raw node, no attributes are parsed."""
    __tag__ = "ProgramInformation"
260
261
class BaseURL(MPDNode):
    """<BaseURL> element: an absolute or relative URL used to resolve segment URLs."""
    __tag__ = "BaseURL"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)
        # An empty <BaseURL/> has text None. Treat it as "" instead of raising
        # AttributeError on .strip().
        self.url = (self.node.text or "").strip()

    @property
    def is_absolute(self):
        """The URL scheme (truthy) when the URL is absolute, otherwise ''."""
        return urlparse(self.url).scheme

    @staticmethod
    def join(url, other):
        """Resolve ``other`` against ``url``.

        An absolute ``other`` is returned unchanged. Otherwise ``url`` is treated as a
        directory (a trailing "/" is added to its path) before joining. When ``url``
        is empty or None, ``other`` is returned as-is.
        """
        # if the other URL is an absolute url, then return that
        if urlparse(other).scheme:
            return other
        elif url:
            parts = list(urlsplit(url))
            if not parts[2].endswith("/"):
                parts[2] += "/"
            url = urlunsplit(parts)
            return urljoin(url, other)
        else:
            return other
286
287
class Location(MPDNode):
    """<Location> element; its text is the URL the manifest should be fetched from."""
    __tag__ = "Location"
290
291
class Period(MPDNode):
    """<Period> element: a time span of the presentation holding its AdaptationSets."""
    __tag__ = "Period"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)
        # index of this Period among its siblings (supplied by MPDNode.children as i=)
        self.i = kwargs.get("i", 0)
        self.id = self.attr("id")
        self.bitstreamSwitching = self.attr("bitstreamSwitching", parser=MPDParsers.bool_str)
        self.duration = self.attr("duration", default=Duration(), parser=MPDParsers.duration)
        # @start defaults to a zero-length Duration, so it is never None. That already
        # gives "first Period of a static MPD starts at 0"; the former
        # `if self.start is None` branch could never run and was removed.
        self.start = self.attr("start", default=Duration(), parser=MPDParsers.duration)

        # TODO: Early Access Periods

        self.baseURLs = self.children(BaseURL)
        self.segmentBase = self.only_child(SegmentBase)
        self.adaptationSets = self.children(AdaptationSet, minimum=1)
        self.segmentList = self.only_child(SegmentList)
        self.segmentTemplate = self.only_child(SegmentTemplate)
        self.assetIdentifier = self.only_child(AssetIdentifier)
        # misspelled name kept as an alias so existing callers keep working
        self.sssetIdentifier = self.assetIdentifier
        self.eventStream = self.children(EventStream)
        self.subset = self.children(Subset)
316
317
class SegmentBase(MPDNode):
    """<SegmentBase> element; recognised but not interpreted by this parser."""
    __tag__ = "SegmentBase"
320
321
class AssetIdentifier(MPDNode):
    """<AssetIdentifier> element; kept as a raw node."""
    __tag__ = "AssetIdentifier"
324
325
class Subset(MPDNode):
    """<Subset> element; kept as a raw node."""
    __tag__ = "Subset"
328
329
class EventStream(MPDNode):
    """<EventStream> element; kept as a raw node."""
    __tag__ = "EventStream"
332
333
class Initialization(MPDNode):
    """<Initialization> element of a SegmentList; exposes @sourceURL as ``source_url``."""
    __tag__ = "Initialization"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)
        source = self.attr("sourceURL")
        # raw @sourceURL string, or None when the attribute is absent
        self.source_url = source
340
341
class SegmentURL(MPDNode):
    """<SegmentURL> element of a SegmentList: one media URL and optional byte range."""
    __tag__ = "SegmentURL"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)
        # @media as a raw string; @mediaRange parsed into (start, length) or None
        self.media, self.media_range = (
            self.attr("media"),
            self.attr("mediaRange", parser=MPDParsers.range),
        )
349
350
class SegmentList(MPDNode):
    """<SegmentList> element: an explicit list of media segment URLs."""
    __tag__ = "SegmentList"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)

        self.presentation_time_offset = self.attr("presentationTimeOffset")
        # @timescale defaults to 1, as in SegmentTemplate. Without a default,
        # @duration without @timescale raised a TypeError below.
        self.timescale = self.attr("timescale", parser=int, default=1)
        self.duration = self.attr("duration", parser=int)
        self.start_number = self.attr("startNumber", parser=int, default=1)

        # per-segment duration in seconds, or None when @duration is absent
        self.duration_seconds = self.duration / float(self.timescale) if self.duration else None

        self.initialization = self.only_child(Initialization)
        self.segment_urls = self.children(SegmentURL, minimum=1)

    @property
    def segments(self):
        """Yield the init Segment (if any), then one Segment per <SegmentURL>, in order."""
        if self.initialization:
            yield Segment(self.make_url(self.initialization.source_url), 0, init=True, content=False)
        for segment_url in self.segment_urls:
            yield Segment(self.make_url(segment_url.media), self.duration_seconds, range=segment_url.media_range)

    def make_url(self, url):
        """Resolve ``url`` against this node's base URL; absolute URLs are returned unchanged."""
        return BaseURL.join(self.base_url, url)
379
380
class AdaptationSet(MPDNode):
    """<AdaptationSet> element: shared attributes plus its child Representations."""
    __tag__ = "AdaptationSet"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)

        # attributes kept as raw strings (None when absent)
        for name in ("id", "group", "mimeType", "lang", "contentType", "par",
                     "minBandwidth", "maxBandwidth"):
            setattr(self, name, self.attr(name))
        # integer picture-size bounds
        for name in ("minWidth", "maxWidth", "minHeight", "maxHeight"):
            setattr(self, name, self.attr(name, parser=int))
        # frame-rate bounds, which may be fractions such as "30000/1001"
        for name in ("minFrameRate", "maxFrameRate"):
            setattr(self, name, self.attr(name, parser=MPDParsers.frame_rate))

        self.segmentAlignment = self.attr("segmentAlignment", default=False, parser=MPDParsers.bool_str)
        self.bitstreamSwitching = self.attr("bitstreamSwitching", parser=MPDParsers.bool_str)
        self.subsegmentAlignment = self.attr("subsegmentAlignment", default=False, parser=MPDParsers.bool_str)
        self.subsegmentStartsWithSAP = self.attr("subsegmentStartsWithSAP", default=0, parser=int)

        # child elements; the SegmentTemplate must be parsed before the Representations,
        # which look it up via walk_back_get_attr
        self.baseURLs = self.children(BaseURL)
        self.segmentTemplate = self.only_child(SegmentTemplate)
        self.representations = self.children(Representation, minimum=1)
        self.contentProtection = self.children(ContentProtection)
410
411
class SegmentTemplate(MPDNode):
    """<SegmentTemplate> element: builds init/media segment URLs from URL templates.

    Missing @duration, @timescale and @startNumber are taken from the ``segmentTemplate``
    attribute of the nearest ancestor that has one (if it is set). Otherwise they
    default to None, 1 and 1.
    """
    __tag__ = "SegmentTemplate"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)
        self.defaultSegmentTemplate = self.walk_back_get_attr('segmentTemplate')

        # str.format callables produced by MPDParsers.segment_template, or None
        self.initialization = self.attr("initialization", parser=MPDParsers.segment_template)
        self.media = self.attr("media", parser=MPDParsers.segment_template)
        self.duration = self.attr("duration", parser=int,
                                  default=self.defaultSegmentTemplate.duration if self.defaultSegmentTemplate else None)
        self.timescale = self.attr("timescale", parser=int,
                                   default=self.defaultSegmentTemplate.timescale if self.defaultSegmentTemplate else 1)
        self.startNumber = self.attr("startNumber", parser=int,
                                     default=self.defaultSegmentTemplate.startNumber if self.defaultSegmentTemplate else 1)
        self.presentationTimeOffset = self.attr("presentationTimeOffset", parser=MPDParsers.timedelta(self.timescale))

        if self.duration:
            self.duration_seconds = self.duration / float(self.timescale)
        else:
            self.duration_seconds = None

        self.period = list(self.walk_back(Period))[0]

        # children
        self.segmentTimeline = self.only_child(SegmentTimeline)

    def segments(self, **kwargs):
        """Yield the init Segment (unless ``init=False``), then the media Segments.

        The remaining keyword arguments are substituted into the URL templates.
        """
        if kwargs.pop("init", True):
            init_url = self.format_initialization(**kwargs)
            if init_url:
                yield Segment(init_url, 0, True, False)
        for media_url, available_at in self.format_media(**kwargs):
            yield Segment(media_url, self.duration_seconds, False, True, available_at)

    def make_url(self, url):
        """
        Join the URL with the base URL, unless it's an absolute URL
        :param url: maybe relative URL
        :return: joined URL
        """
        return BaseURL.join(self.base_url, url)

    def format_initialization(self, **kwargs):
        """Return the resolved init segment URL, or None when there is no @initialization."""
        if self.initialization:
            return self.make_url(self.initialization(**kwargs))

    def _presentation_duration_seconds(self):
        """Total seconds of the enclosing Period, else of MPD@mediaPresentationDuration, else 0.

        Uses total_seconds(). The previous ``.seconds`` is only the seconds *component*
        (0-86399) and is wrong for durations of a day or more. An absent
        mediaPresentationDuration no longer raises AttributeError.
        """
        for source in (self.period.duration, self.root.mediaPresentationDuration):
            if source is not None:
                seconds = source.total_seconds()
                if seconds:
                    return seconds
        return 0

    def segment_numbers(self):
        """
        yield the segment number and when it will be available
        There are two cases for segment number generation, static and dynamic.

        In the case of static stream, the segment number starts at the startNumber and counts
        up to the number of segments that are represented by the periods duration
        (ceil(duration / segment duration), so a trailing partial segment is included).

        In the case of dynamic streams, the segments should appear at the specified time
        in the simplest case the segment number is based on the time since the availabilityStartTime.
        Numbers never go below startNumber.
        :return: iterator of (segment number, availability datetime)
        """
        log.debug("Generating segment numbers for {0} playlist (id={1})".format(self.root.type, self.parent.id))
        if self.root.type == "static":
            available_iter = repeat(epoch_start)
            duration = self._presentation_duration_seconds()
            if duration:
                # ceil(duration / (self.duration / timescale)), computed in integers
                # (microseconds x timescale) so exact multiples are not thrown off by
                # float rounding
                duration_us = int(round(duration * 1000000))
                num_segments = -(-(duration_us * self.timescale) // (self.duration * 1000000))
                number_iter = range(self.startNumber, self.startNumber + num_segments)
            else:
                number_iter = count(self.startNumber)
        else:
            now = datetime.datetime.now(utc)
            if self.presentationTimeOffset:
                since_start = (now - self.presentationTimeOffset) - self.root.availabilityStartTime
                available_start = self.root.availabilityStartTime + self.presentationTimeOffset + since_start
            else:
                since_start = now - self.root.availabilityStartTime
                available_start = now

            # if there is no delay, use a delay of 3 seconds
            suggested_delay = datetime.timedelta(seconds=(self.root.suggestedPresentationDelay.total_seconds()
                                                          if self.root.suggestedPresentationDelay
                                                          else 3))

            # the number of the segment that is available at NOW - SUGGESTED_DELAY - BUFFER_TIME;
            # clamped to startNumber so a stream that has only just started does not
            # produce numbers before its first segment
            first_number = self.startNumber + int(
                (since_start - suggested_delay - self.root.minBufferTime).total_seconds()
                / self.duration_seconds
            )
            number_iter = count(max(self.startNumber, first_number))

            # the time the segment number is available at NOW
            available_iter = count_dt(available_start,
                                      step=datetime.timedelta(seconds=self.duration_seconds))

        yield from zip(number_iter, available_iter)

    def format_media(self, **kwargs):
        """Yield ``(url, available_at)`` for each media segment.

        With a SegmentTimeline on a dynamic MPD, only segments whose @t is greater than
        ``root.timelines[parent.id]`` are yielded, and that entry is advanced. On the
        first pass (entry still -1), the backwards walk stops once suggested_delay has
        been covered.
        """
        if self.segmentTimeline:
            if self.parent.id is None:
                # workaround for invalid `self.root.timelines[self.parent.id]`
                # creates a timeline for every mimeType instead of one for both
                self.parent.id = self.parent.mimeType
            log.debug("Generating segment timeline for {0} playlist (id={1})".format(self.root.type, self.parent.id))
            if self.root.type == "dynamic":
                # if there is no delay, use a delay of 3 seconds
                suggested_delay = datetime.timedelta(seconds=(self.root.suggestedPresentationDelay.total_seconds()
                                                              if self.root.suggestedPresentationDelay
                                                              else 3))
                publish_time = self.root.publishTime or epoch_start

                # transform the time line in to a segment list
                timeline = []
                available_at = publish_time
                for segment, n in reversed(list(zip(self.segmentTimeline.segments, count(self.startNumber)))):
                    # the last segment in the timeline is the most recent
                    # so, work backwards and calculate when each of the segments was
                    # available, based on the durations relative to the publish time
                    url = self.make_url(self.media(Time=segment.t, Number=n, **kwargs))
                    duration = datetime.timedelta(seconds=segment.d / self.timescale)

                    # once the suggested_delay is reach stop
                    if self.root.timelines[self.parent.id] == -1 and publish_time - available_at >= suggested_delay:
                        break

                    timeline.append((url, available_at, segment.t))

                    available_at -= duration  # walk backwards in time

                # return the segments in chronological order
                for url, available_at, t in reversed(timeline):
                    if t > self.root.timelines[self.parent.id]:
                        self.root.timelines[self.parent.id] = t
                        yield (url, available_at)

            else:
                for segment, n in zip(self.segmentTimeline.segments, count(self.startNumber)):
                    yield (self.make_url(self.media(Time=segment.t, Number=n, **kwargs)),
                           datetime.datetime.now(tz=utc))

        else:
            for number, available_at in self.segment_numbers():
                yield (self.make_url(self.media(Number=number, **kwargs)),
                       available_at)
556
557
class Representation(MPDNode):
    """<Representation> element: one encoded version of the content (video/audio/subtitle)."""
    __tag__ = "Representation"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)
        self.id = self.attr("id", required=True)
        # @bandwidth divided by 1000 (kbit/s, assuming @bandwidth is in bit/s as DASH specifies)
        self.bandwidth = self.attr("bandwidth", parser=lambda b: float(b) / 1000.0, required=True)
        self.mimeType = self.attr("mimeType", required=True, inherited=True)

        self.codecs = self.attr("codecs")
        self.startWithSAP = self.attr("startWithSAP")

        # video
        self.width = self.attr("width", parser=int)
        self.height = self.attr("height", parser=int)
        self.frameRate = self.attr("frameRate", parser=MPDParsers.frame_rate)

        # audio
        self.audioSamplingRate = self.attr("audioSamplingRate", parser=int)
        self.numChannels = self.attr("numChannels", parser=int)

        # subtitle
        self.lang = self.attr("lang", inherited=True)

        self.baseURLs = self.children(BaseURL)
        self.subRepresentation = self.children(SubRepresentation)
        self.segmentBase = self.only_child(SegmentBase)
        self.segmentList = self.children(SegmentList)
        self.segmentTemplate = self.only_child(SegmentTemplate)

    @property
    def bandwidth_rounded(self):
        """``bandwidth`` rounded to about two significant digits.

        A non-positive bandwidth is returned unchanged, because log10 is undefined there.
        """
        if self.bandwidth <= 0:
            return self.bandwidth
        return round(self.bandwidth, 1 - int(math.log10(self.bandwidth)))

    def segments(self, **kwargs):
        """
        Segments are yielded when they are available

        Segments appear on a time line, for dynamic content they are only available at a certain time
        and sometimes for a limited time. For static content they are all available at the same time.

        :param kwargs: extra args to pass to the segment template
        :return: yields Segments
        """

        # segmentBase = self.segmentBase or self.walk_back_get_attr("segmentBase")
        segmentLists = self.segmentList or self.walk_back_get_attr("segmentList")
        segmentTemplate = self.segmentTemplate or self.walk_back_get_attr("segmentTemplate")

        if segmentTemplate:
            # round() rather than int(): bandwidth was divided by 1000 as a float, and
            # multiplying back can land just below the integer (e.g. 4099.999...)
            yield from segmentTemplate.segments(RepresentationID=self.id,
                                                Bandwidth=round(self.bandwidth * 1000),
                                                **kwargs)
        elif segmentLists:
            # the Representation's own lists come from children() (a list), but an
            # inherited Period-level SegmentList is a single node from only_child()
            if not isinstance(segmentLists, (list, tuple)):
                segmentLists = [segmentLists]
            for segmentList in segmentLists:
                yield from segmentList.segments
        else:
            yield Segment(self.base_url, 0, True, True)
621
622
class SubRepresentation(MPDNode):
    """<SubRepresentation> element; kept as a raw node."""
    __tag__ = "SubRepresentation"
625
626
class SegmentTimeline(MPDNode):
    """<SegmentTimeline> element: explicit segment start times/durations as <S> entries."""
    __tag__ = "SegmentTimeline"
    TimelineSegment = namedtuple("TimelineSegment", "t d")

    def __init__(self, node, *args, **kwargs):
        super().__init__(node, *args, **kwargs)

        # timescale of the nearest ancestor that defines one (stored, not used here)
        self.timescale = self.walk_back_get_attr("timescale")

        self.timeline_segments = self.children(_TimelineSegment)

    @property
    def segments(self):
        """Expand the <S> entries into individual ``TimelineSegment(t, d)`` tuples.

        Start times:

        * An explicit @t on any <S> re-anchors the running start time, which covers
          gaps/discontinuities. Previously only the first @t was honoured.
        * Without @t, a segment starts where the previous one ended.

        Repeats:

        * @r="-1" means "repeat until the next <S>". It is resolved from the next
          entry's @t.
        * When the next entry's @t is unknown, @r="-1" yields a single segment instead
          of none.
        """
        t = 0
        entries = self.timeline_segments
        for index, entry in enumerate(entries):
            if entry.t is not None:
                t = entry.t
            repeat = entry.r
            if repeat < 0:
                following = entries[index + 1] if index + 1 < len(entries) else None
                repeat = self._open_ended_repeat(t, entry.d, following)
            for _ in range(repeat + 1):
                yield self.TimelineSegment(t, entry.d)
                t += entry.d

    @staticmethod
    def _open_ended_repeat(t, d, following):
        """Return the repeat count implied by @r="-1": fill from ``t`` up to ``following.t``."""
        next_t = following.t if following is not None else None
        if next_t is None or not d or next_t <= t:
            return 0
        # ceil((next_t - t) / d) segments in total, i.e. one fewer repeats
        return -(-(next_t - t) // d) - 1
648
649
class _TimelineSegment(MPDNode):
    """One <S> entry of a SegmentTimeline."""
    __tag__ = "S"

    def __init__(self, node, *args, **kwargs):
        super().__init__(node, *args, **kwargs)

        # @t start time, @d duration (both divided by the timescale by SegmentTemplate),
        # and @r repeat count; @t/@d are None when absent, @r defaults to 0
        for name, fallback in (("t", None), ("d", None), ("r", 0)):
            setattr(self, name, self.attr(name, default=fallback, parser=int))
659
660
class ContentProtection(MPDNode):
    """<ContentProtection> element: content-protection scheme signalling."""
    __tag__ = "ContentProtection"

    def __init__(self, node, root=None, parent=None, *args, **kwargs):
        super().__init__(node, root, parent, *args, **kwargs)

        # raw string values (None when absent), read in document-attribute order
        self.schemeIdUri, self.value, self.default_KID = (
            self.attr(name) for name in ("schemeIdUri", "value", "default_KID")
        )
670