1# Copyright (C) 2008-2010 Adam Olsen
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation; either version 2, or (at your option)
6# any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program; if not, write to the Free Software
15# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
16#
17#
18# The developers of the Exaile media player hereby grant permission
19# for non-GPL compatible GStreamer and Exaile plugins to be used and
20# distributed together with GStreamer and Exaile. This permission is
21# above and beyond the permissions granted by the GPL license by which
22# Exaile is covered. If you modify this code, you may extend this
23# exception to your version of the code, but you are not obligated to
24# do so. If you do not wish to do so, delete this exception statement
25# from your version.
26
27
28from copy import deepcopy
29import logging
30from time import time
31from typing import Dict, Iterable, Iterator, List, Optional, Tuple
32
33from xl import common, event
34from xl.nls import gettext as _
35from xl.trax.track import Track
36
# Module-level logger; used for load/save progress and error reporting below.
logger = logging.getLogger(__name__)
38
39
class TrackHolder:
    """
    Wraps a :class:`Track` stored in a :class:`TrackDB` together with the
    per-DB bookkeeping that is persisted alongside it.

    Any attribute not found on the holder itself is looked up on the
    wrapped track, so a holder can mostly stand in for its track.

    :param track: the wrapped track object
    :param key: the DB key of this track (the track is saved to the shelf
        under ``"tracks-<key>"``)
    :param kwargs: extra per-track attributes saved together with the track
    """

    def __init__(self, track, key, **kwargs):
        self._track = track
        self._key = key
        self._attrs = kwargs

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup has failed. Read _track
        # straight from the instance dict: on a half-initialised instance
        # (e.g. one built by copy/pickle without calling __init__), going
        # through self._track would re-enter __getattr__ and recurse until
        # RecursionError instead of raising a proper AttributeError.
        try:
            track = self.__dict__['_track']
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(track, attr)
48
49
class TrackDBIterator:
    """
    Iterator adaptor: consumes ``(location, TrackHolder)`` pairs, such as
    those produced by ``TrackDB.tracks.items()``, and yields the bare
    track object held by each holder.
    """

    def __init__(self, track_iterator: Iterator[Tuple[str, TrackHolder]]):
        # The wrapped source of (location, holder) pairs.
        self.iter = track_iterator

    def __iter__(self):
        # Acts as its own iterator so it can be used directly in a for-loop.
        return self

    def __next__(self):
        # StopIteration (or a dict-size-changed RuntimeError) from the
        # wrapped iterator propagates unchanged.
        pair = next(self.iter)
        holder = pair[1]
        return holder._track
59
60
class TrackDB:
    """
    Manages a track database.

    Allows you to add, remove, retrieve, search, save and load
    Track objects. Tracks are keyed by the location returned from
    ``get_loc_for_io()``, so each location is stored at most once.

    :param name:   The name of this :class:`TrackDB`.
    :param location:   Path to a file where this :class:`TrackDB`
            should be stored. If given, the DB is loaded from it
            immediately and :meth:`_timeout_save` (the auto-save
            callback) is invoked.
    :param pickle_attrs:   A list of extra attributes to store in the
            pickled representation of this object. All
            attributes listed must be built-in types. ``'tracks'``,
            ``'name'`` and ``'_key'`` are always stored; the attribute
            named exactly ``'tracks'`` is special-cased and saved as one
            shelf entry per track. The given list is copied, never
            modified.
    :param loadfirst: Set to True if this collection should be
            loaded before any tracks are created.
    :raises RuntimeError: if *loadfirst* is set and tracks already exist.
    """

    def __init__(
        self,
        name: str = "",
        location: str = "",
        pickle_attrs: Optional[List[str]] = None,
        loadfirst: bool = False,
    ):
        """
        Sets up the trackDB.
        """

        # ensure that the DB is always loaded before any tracks are,
        # otherwise internal values are not loaded and may be lost/corrupted
        if loadfirst and Track._get_track_count() != 0:
            raise RuntimeError(
                (
                    "Internal error! %d tracks already loaded, "
                    + "TrackDB must be loaded first!"
                )
                % Track._get_track_count()
            )

        self.name = name
        self.location = location
        self._dirty = False
        self.tracks: Dict[str, TrackHolder] = {}  # key is URI of the track
        # Copy instead of extending in place: `+=` on the argument would
        # mutate the caller's list (and, with a shared default list, make
        # every later TrackDB accumulate duplicate entries).
        self.pickle_attrs: List[str] = list(pickle_attrs) if pickle_attrs else []
        for attr in ('tracks', 'name', '_key'):
            if attr not in self.pickle_attrs:
                self.pickle_attrs.append(attr)
        self._saving = False
        # Next key handed out by add_tracks(); persisted via pickle_attrs.
        self._key = 0
        self._dbversion = 2.0
        self._dbminorversion = 0
        # Keys of removed tracks whose shelf entries still need deleting.
        self._deleted_keys = []
        if location:
            self.load_from_location()
            self._timeout_save()

    def __iter__(self):
        """
        Provide the ability to iterate over the Track objects in a TrackDB.
        Just as with a dictionary, if tracks are added
        or removed during iteration, iteration will halt
        with a RuntimeError.
        """
        return TrackDBIterator(iter(self.tracks.items()))

    def __len__(self):
        """
        Obtain a count of how many items are in the TrackDB
        """
        return len(self.tracks)

    @common.glib_wait_seconds(300)
    def _timeout_save(self):
        """
        Callback for auto-saving.

        Returns True; presumably this keeps the periodic timeout installed
        by the decorator running (TODO confirm against common.glib_wait_seconds).
        """
        self.save_to_location()
        return True

    def set_name(self, name):
        """
        Sets the name of this :class:`TrackDB` and marks the DB dirty.

        :param name:   The new name.
        :type name: string
        """
        self.name = name
        self._dirty = True

    def get_name(self):
        """
        Gets the name of this :class:`TrackDB`

        :return: The name.
        :rtype: string
        """
        return self.name

    def set_location(self, location):
        """
        Sets the location to save to and marks the DB dirty.

        :param location: the location to save to
        """
        self.location = location
        self._dirty = True

    @common.synchronized
    def load_from_location(self, location=None):
        """
        Restores :class:`TrackDB` state from the pickled representation
        stored at the specified location.

        A DB in an older format is backed up to ``<location>-<version>.bak``
        and migrated first. Duplicate track entries (same location) are
        removed from the shelf, keeping the first one found. An error while
        restoring one attribute is logged and that attribute is left as is.

        :param location: the location to load the data from; defaults to
            :attr:`location`
        :type location: string
        :raises AttributeError: if no location is given or configured
        :raises common.VersionError: if the DB's major format version is
            newer than the one this code writes
        """
        if not location:
            location = self.location
        if not location:
            raise AttributeError(
                _("You did not specify a location to load the db from")
            )

        logger.debug("Loading %s DB from %s.", self.name, location)

        pdata = common.open_shelf(location)
        try:
            if "_dbversion" in pdata:
                if int(pdata['_dbversion']) > int(self._dbversion):
                    raise common.VersionError(
                        "DB was created on a newer Exaile version."
                    )
                elif pdata['_dbversion'] < self._dbversion:
                    logger.info("Upgrading DB format....")
                    import shutil

                    shutil.copyfile(
                        location, location + "-%s.bak" % pdata['_dbversion']
                    )
                    import xl.migrations.database as dbmig

                    dbmig.handle_migration(
                        self, pdata, pdata['_dbversion'], self._dbversion
                    )

            for attr in self.pickle_attrs:
                try:
                    if 'tracks' == attr:
                        setattr(self, attr, self._load_tracks(pdata))
                    else:
                        setattr(self, attr, pdata.get(attr, getattr(self, attr)))
                except Exception:
                    # FIXME: Do something about this
                    logger.exception("Exception occurred while loading %s", location)
        finally:
            # Close the shelf even if the version check or migration raised.
            pdata.close()

        self._dirty = False

    @staticmethod
    def _load_tracks(pdata) -> Dict[str, TrackHolder]:
        """
        Rebuild the ``tracks`` mapping from the ``tracks-<key>`` shelf entries.

        Each entry is a ``(pickled track, key, attrs)`` tuple as written by
        :meth:`save_to_location`. Later duplicates of an already-seen
        location are deleted from the shelf.
        """
        data: Dict[str, TrackHolder] = {}
        # Snapshot the keys first: duplicates are deleted from the shelf
        # below, and deleting from a mapping while iterating its live key
        # view is not safe for every shelf backend.
        track_keys = [k for k in pdata.keys() if k.startswith("tracks-")]
        for k in track_keys:
            p = pdata[k]
            tr = Track(_unpickles=p[0])
            loc = tr.get_loc_for_io()
            if loc not in data:
                data[loc] = TrackHolder(tr, p[1], **p[2])
            else:
                logger.warning("Duplicate track found: %s", loc)
                # presumably the second track was written because of an error,
                # so use the first track found.
                del pdata[k]
        return data

    @common.synchronized
    def save_to_location(self, location=None):
        """
        Saves a pickled representation of this :class:`TrackDB` to the
        specified location.

        Nothing is written unless the DB or at least one of its tracks is
        dirty. Only tracks that are dirty or missing from the shelf are
        rewritten; shelf entries of removed tracks are deleted. Failing to
        open the shelf (or finding it was written by a newer version) is
        logged and the save is skipped.

        :param location: the location to save the data to; defaults to
            :attr:`location`
        :type location: string
        :raises AttributeError: if there is something to save but no
            location is given or configured
        """
        if not self._dirty:
            self._dirty = any(
                holder._track._dirty for holder in self.tracks.values()
            )

        if not self._dirty:
            return

        if not location:
            location = self.location
        if not location:
            raise AttributeError(_("You did not specify a location to save the db"))

        # Skip if a save is already in progress.
        if self._saving:
            return
        self._saving = True
        try:
            self._write_to_shelf(location)
        finally:
            # Always clear the flag: leaving it set after an error would
            # silently disable every later save.
            self._saving = False

    def _write_to_shelf(self, location) -> None:
        """
        Write the current state to the shelf at *location*, then mark the
        DB and its tracks clean. On open/version failure, log and return
        without changing any dirty state.
        """
        logger.debug("Saving %s DB to %s.", self.name, location)

        pdata = None
        try:
            pdata = common.open_shelf(location)
            if pdata.get('_dbversion', self._dbversion) > self._dbversion:
                raise common.VersionError("DB was created on a newer Exaile.")
        except Exception:
            logger.exception("Failed to open music DB for writing.")
            if pdata is not None:
                pdata.close()
            return

        # Only the keys present now are purged; anything appended later
        # stays queued for the next save.
        purged = len(self._deleted_keys)
        try:
            for attr in self.pickle_attrs:
                # bad hack to allow saving of lists/dicts of Tracks
                if 'tracks' == attr:
                    for holder in self.tracks.values():
                        key = "tracks-%s" % holder._key
                        if holder._track._dirty or key not in pdata:
                            pdata[key] = (
                                holder._track._pickles(),
                                holder._key,
                                deepcopy(holder._attrs),
                            )
                else:
                    pdata[attr] = deepcopy(getattr(self, attr))

            pdata['_dbversion'] = self._dbversion

            for deleted_key in self._deleted_keys[:purged]:
                shelf_key = "tracks-%s" % deleted_key
                if shelf_key in pdata:
                    del pdata[shelf_key]

            pdata.sync()
        finally:
            pdata.close()

        # Only reached once the shelf has been written and synced.
        del self._deleted_keys[:purged]
        for holder in self.tracks.values():
            holder._track._dirty = False
        self._dirty = False

    def get_track_by_loc(self, loc: str, raw=False) -> Optional[Track]:
        """
        returns the track having the given loc. if no such track exists,
        returns None

        :param raw: currently ignored
        """
        holder = self.tracks.get(loc)
        return None if holder is None else holder._track

    def loc_is_member(self, loc: str) -> bool:
        """
        Returns True if loc is a track in this collection, False
        if it is not
        """
        return loc in self.tracks

    def get_count(self) -> int:
        """
        Returns the number of tracks stored in this database
        """
        return len(self.tracks)

    def add(self, track: Track) -> None:
        """
        Adds a track to the database of tracks

        :param track: The :class:`xl.trax.Track` to add
        """
        self.add_tracks([track])

    @common.synchronized
    def add_tracks(self, tracks: Iterable[Track]) -> None:
        """
        Like add(), but takes a list of :class:`xl.trax.Track`

        Every given track without a ``__date_added`` tag (including ones
        skipped below) gets the current time. Tracks whose location is
        already in the DB are skipped. If anything was added, a single
        ``tracks_added`` event listing the new locations is emitted and
        the DB is marked dirty.
        """
        locations = []
        now = time()
        for tr in tracks:
            if not tr.get_tag_raw('__date_added'):
                tr.set_tags(__date_added=now)
            location = tr.get_loc_for_io()
            # Don't add duplicates -- track URLs are unique
            if location in self.tracks:
                continue
            locations.append(location)
            self.tracks[location] = TrackHolder(tr, self._key)
            self._key += 1

        if locations:
            event.log_event('tracks_added', self, locations)
            self._dirty = True

    def remove(self, track: Track) -> None:
        """
        Removes a track from the database

        :param track: the :class:`xl.trax.Track` to remove
        """
        self.remove_tracks([track])

    @common.synchronized
    def remove_tracks(self, tracks: Iterable[Track]) -> None:
        """
        Like remove(), but takes a list of :class:`xl.trax.Track`

        Tracks that are not in the DB are skipped (mirroring how
        add_tracks() skips duplicates) rather than aborting half-way. If
        anything was removed, a single ``tracks_removed`` event listing
        the removed locations is emitted and the DB is marked dirty.
        """
        locations = []

        for tr in tracks:
            location = tr.get_loc_for_io()
            holder = self.tracks.pop(location, None)
            if holder is None:
                logger.debug("Not removing %s: not in %s DB", location, self.name)
                continue
            locations.append(location)
            # Remember the key so the shelf entry is deleted on next save.
            self._deleted_keys.append(holder._key)

        if locations:
            event.log_event('tracks_removed', self, locations)
            self._dirty = True

    def get_tracks(self) -> List[Track]:
        """
        Returns a new list of all Track objects in this database.
        """
        return list(self)