1# -*- coding: utf-8 -*-
2
3import logging
4import os
5import shutil
6import tempfile
7
8from babelfish import language_converters
9import fese
10from fese import check_integrity
11from fese import FFprobeSubtitleStream
12from fese import FFprobeVideoContainer
13from fese import to_srt
14from subliminal.subtitle import fix_line_ending
15from subliminal_patch.core import Episode
16from subliminal_patch.core import Movie
17from subliminal_patch.providers import Provider
18from subliminal_patch.subtitle import Subtitle
19from subzero.language import Language
20
logger = logging.getLogger(__name__)

# Replace Babelfish's Language with Subzero's Language in the fese module
# namespace. This is a process-wide patch applied at import time.
# NOTE(review): presumably this makes the stream languages reported by fese
# Subzero Language objects (EmbeddedSubtitle calls Language.rebuild on them);
# confirm that fese resolves ``Language`` through its module globals.
fese.Language = Language
25
26
class EmbeddedSubtitle(Subtitle):
    """A subtitle stream embedded in a local video container.

    The content is not fetched remotely. The provider extracts it from
    ``container`` when the subtitle is downloaded.
    """

    provider_name = "embeddedsubtitles"
    hash_verifiable = False

    def __init__(self, stream, container, matches):
        """
        :param stream: the ``FFprobeSubtitleStream`` found in the container.
        :param container: the ``FFprobeVideoContainer`` holding ``stream``.
        :param matches: initial set of match keys (e.g. ``{"hash"}``). It is
            copied, because ``get_matches`` adds keys to the stored set and
            must not mutate the caller's object.
        """
        super().__init__(stream.language, stream.disposition.hearing_impaired)
        if stream.disposition.forced:
            # Rebuild the language with the forced flag when the stream's
            # disposition marks it as forced.
            self.language = Language.rebuild(stream.language, forced=True)

        self.stream: FFprobeSubtitleStream = stream
        self.container: FFprobeVideoContainer = container
        self.forced = stream.disposition.forced
        # Private copy: get_matches() adds to this set.
        self._matches: set = set(matches)
        # The "page" of an embedded subtitle is the video file itself.
        self.page_link = self.container.path
        self.release_info = os.path.basename(self.page_link)

    def get_matches(self, video):
        """Return the match keys for this subtitle.

        The subtitle comes from the video file itself, so "hash" is always
        reported. "hearing_impaired" is added for hearing-impaired streams.
        """
        if self.hearing_impaired:
            self._matches.add("hearing_impaired")

        self._matches.add("hash")
        return self._matches

    @property
    def id(self):
        """Unique id: the container path and the stream index, joined by '_'."""
        return f"{self.container.path}_{self.stream.index}"
53
54
class EmbeddedSubtitlesProvider(Provider):
    """List and extract subtitles embedded in local video containers.

    Streams are probed and extracted through ``fese`` (ffprobe/ffmpeg).
    Only ASS and SRT streams are considered, and ASS streams are converted
    to SRT when downloaded.
    """

    provider_name = "embeddedsubtitles"

    languages = {Language("por", "BR"), Language("spa", "MX")} | {
        Language.fromalpha2(l) for l in language_converters["alpha2"].codes
    }
    # Add hearing-impaired variants, then forced variants. The forced pass
    # runs after the HI update, so it also iterates over the HI variants.
    languages.update(set(Language.rebuild(lang, hi=True) for lang in languages))
    languages.update(set(Language.rebuild(lang, forced=True) for lang in languages))

    video_types = (Episode, Movie)
    subtitle_class = EmbeddedSubtitle

    def __init__(
        self,
        include_ass=True,
        include_srt=True,
        cache_dir=None,
        ffprobe_path=None,
        ffmpeg_path=None,
    ):
        """
        :param include_ass: whether ASS streams are listed.
        :param include_srt: whether SRT streams are listed.
        :param cache_dir: base directory for extracted subtitles; defaults to
            the system temp dir. A subdirectory named after this class is used.
        :param ffprobe_path: ffprobe executable; fese's current value is kept
            when this is falsy.
        :param ffmpeg_path: ffmpeg executable; fese's current value is kept
            when this is falsy.
        """
        self._include_ass = include_ass
        self._include_srt = include_srt
        self._cache_dir = os.path.join(
            cache_dir or tempfile.gettempdir(), self.__class__.__name__.lower()
        )
        # container path -> extracted subtitle paths, looked up by stream index
        self._cached_paths = {}

        # NOTE: these are fese module globals, so they affect every fese user
        # in the process, not only this instance.
        fese.FFPROBE_PATH = ffprobe_path or fese.FFPROBE_PATH
        fese.FFMPEG_PATH = ffmpeg_path or fese.FFMPEG_PATH

        if logger.getEffectiveLevel() == logging.DEBUG:
            fese.FF_LOG_LEVEL = "warning"
        else:
            # Default is True
            fese.FFMPEG_STATS = False

    def initialize(self):
        """Create the cache directory used for extracted subtitles."""
        os.makedirs(self._cache_dir, exist_ok=True)

    def terminate(self):
        """Remove extracted subtitles and forget their cached paths."""
        # Remove leftovers
        shutil.rmtree(self._cache_dir, ignore_errors=True)
        # The files are gone. Drop the stale paths so a re-initialized
        # provider extracts again instead of opening deleted files.
        self._cached_paths.clear()

    def query(self, path: str, languages):
        """Return the wanted embedded subtitles of the video at ``path``.

        :param path: path to the video container.
        :param languages: requested languages. If all of them are forced,
            only forced streams are returned. If any is forced, forced
            streams are accepted in addition to generic/HI ones.
        :return: list of :class:`EmbeddedSubtitle`.
        """
        video = FFprobeVideoContainer(path)

        try:
            # Materialize the filter. A lazy ``filter`` object is always
            # truthy, which made the emptiness check below unreachable.
            streams = list(filter(_check_allowed_extensions, video.get_subtitles()))
        except fese.InvalidSource as error:
            logger.error("Error trying to get subtitles for %s: %s", video, error)
            streams = []

        if not streams:
            logger.debug("No subtitles found for container: %s", video)

        only_forced = all(lang.forced for lang in languages)
        also_forced = any(lang.forced for lang in languages)

        subtitles = []

        for stream in streams:
            if not self._include_ass and stream.extension == "ass":
                logger.debug("Ignoring ASS: %s", stream)
                continue

            if not self._include_srt and stream.extension == "srt":
                logger.debug("Ignoring SRT: %s", stream)
                continue

            if stream.language not in languages:
                continue

            disposition = stream.disposition

            if only_forced and not disposition.forced:
                continue

            if (
                disposition.generic
                or disposition.hearing_impaired
                or (disposition.forced and also_forced)
            ):
                logger.debug("Appending subtitle: %s", stream)
                subtitles.append(EmbeddedSubtitle(stream, video, {"hash"}))
            else:
                logger.debug("Ignoring unwanted subtitle: %s", stream)

        return subtitles

    def list_subtitles(self, video, languages):
        """List embedded subtitles of ``video``; empty if its file does not exist."""
        if not os.path.isfile(video.original_path):
            logger.debug("Ignoring inexistent file: %s", video.original_path)
            return []

        return self.query(video.original_path, languages)

    def download_subtitle(self, subtitle):
        """Extract ``subtitle`` if needed and load it into ``subtitle.content``.

        The raw bytes are passed through subliminal's ``fix_line_ending``.
        """
        path = self._get_subtitle_path(subtitle)
        with open(path, "rb") as sub:
            content = sub.read()
            subtitle.content = fix_line_ending(content)

    def _get_subtitle_path(self, subtitle: EmbeddedSubtitle):
        """Return the path of the extracted (SRT) file for ``subtitle``.

        On the first request for a container, all of its allowed subtitle
        streams are extracted into the cache directory at once.
        """
        container = subtitle.container

        # Check if the container is not already in the instance
        if container.path not in self._cached_paths:
            # Extract all subtitle streams to avoid reading the entire
            # container over and over
            streams = filter(_check_allowed_extensions, container.get_subtitles())
            extracted = container.extract_subtitles(list(streams), self._cache_dir)
            # Store the extracted paths under the container path key
            self._cached_paths[container.path] = extracted

        cached_path = self._cached_paths[container.path]
        # Get the subtitle file by index
        subtitle_path = cached_path[subtitle.stream.index]

        check_integrity(subtitle.stream, subtitle_path)

        # Convert to SRT if the subtitle is ASS (the source file is removed,
        # so point the cache at the converted file)
        new_subtitle_path = to_srt(subtitle_path, remove_source=True)
        if new_subtitle_path != subtitle_path:
            cached_path[subtitle.stream.index] = new_subtitle_path

        return new_subtitle_path
181
182
def _check_allowed_extensions(subtitle: FFprobeSubtitleStream):
    """Tell whether ``subtitle`` is an ASS or SRT stream.

    Only these two text formats are listed and extracted by the provider.
    """
    supported = ("ass", "srt")
    extension = subtitle.extension
    return extension in supported
185