1# -*- coding: utf-8 -*-
2# This file is part of beets.
3# Copyright 2016, Adrian Sampson.
4#
5# Permission is hereby granted, free of charge, to any person obtaining
6# a copy of this software and associated documentation files (the
7# "Software"), to deal in the Software without restriction, including
8# without limitation the rights to use, copy, modify, merge, publish,
9# distribute, sublicense, and/or sell copies of the Software, and to
10# permit persons to whom the Software is furnished to do so, subject to
11# the following conditions:
12#
13# The above copyright notice and this permission notice shall be
14# included in all copies or substantial portions of the Software.
15
16"""Glue between metadata sources and the matching logic."""
17from __future__ import division, absolute_import, print_function
18
19from collections import namedtuple
20from functools import total_ordering
21import re
22
23from beets import logging
24from beets import plugins
25from beets import config
26from beets.util import as_string
27from beets.autotag import mb
28from jellyfish import levenshtein_distance
29from unidecode import unidecode
30import six
31
# Logger named 'beets'; the lookup helpers below hand it to
# MusicBrainzAPIError.log() when a MusicBrainz request fails.
log = logging.getLogger('beets')

# Type of compiled regular expressions. Older Pythons expose it as
# ``re._pattern_type``; Python 3.7 renamed it to ``re.Pattern``.
if hasattr(re, '_pattern_type'):
    Pattern = re._pattern_type
else:
    Pattern = re.Pattern
39
40
41# Classes used to represent candidate options.
42
class AlbumInfo(object):
    """A canonical release that library albums can be matched against.

    Every constructor argument is stored unchanged on an attribute of
    the same name:

    - ``album``: the release title
    - ``album_id``: MusicBrainz ID; UUID fragment only
    - ``artist``: name of the release's primary artist
    - ``artist_id``
    - ``tracks``: list of TrackInfo objects making up the release
    - ``asin``: Amazon ASIN
    - ``albumtype``: string describing the kind of release
    - ``va``: boolean: whether the release has "various artists"
    - ``year``: release year
    - ``month``: release month
    - ``day``: release day
    - ``label``: music label responsible for the release
    - ``mediums``: the number of discs in this release
    - ``artist_sort``: name of the release's artist for sorting
    - ``releasegroup_id``: MBID for the album's release group
    - ``catalognum``: the label's catalog number for the release
    - ``script``: character set used for metadata
    - ``language``: human language of the metadata
    - ``country``: the release country
    - ``albumstatus``: MusicBrainz release status (Official, etc.)
    - ``media``: delivery mechanism (Vinyl, etc.)
    - ``albumdisambig``: MusicBrainz release disambiguation comment
    - ``releasegroupdisambig``: MusicBrainz release group
            disambiguation comment.
    - ``artist_credit``: Release-specific artist name
    - ``original_year``, ``original_month``, ``original_day``: optional
            date parts (presumably the original release date -- confirm
            against the data sources that fill them in)
    - ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
    - ``data_url``: The data source release URL.

    ``mediums`` along with the fields up through ``tracks`` are required.
    The others are optional and may be None.
    """
    def __init__(self, album, album_id, artist, artist_id, tracks, asin=None,
                 albumtype=None, va=False, year=None, month=None, day=None,
                 label=None, mediums=None, artist_sort=None,
                 releasegroup_id=None, catalognum=None, script=None,
                 language=None, country=None, albumstatus=None, media=None,
                 albumdisambig=None, releasegroupdisambig=None,
                 artist_credit=None, original_year=None, original_month=None,
                 original_day=None, data_source=None, data_url=None):
        # Positional core of the release.
        self.album = album
        self.album_id = album_id
        self.artist = artist
        self.artist_id = artist_id
        self.tracks = tracks

        # Optional descriptive metadata.
        self.asin = asin
        self.albumtype = albumtype
        self.va = va
        self.year = year
        self.month = month
        self.day = day
        self.label = label
        self.mediums = mediums
        self.artist_sort = artist_sort
        self.releasegroup_id = releasegroup_id
        self.catalognum = catalognum
        self.script = script
        self.language = language
        self.country = country
        self.albumstatus = albumstatus
        self.media = media
        self.albumdisambig = albumdisambig
        self.releasegroupdisambig = releasegroupdisambig
        self.artist_credit = artist_credit
        self.original_year = original_year
        self.original_month = original_month
        self.original_day = original_day

        # Provenance of the metadata.
        self.data_source = data_source
        self.data_url = data_url

    # python-musicbrainz-ngs sometimes hands back bytes where Unicode
    # is expected; this method papers over that.
    # https://github.com/alastair/python-musicbrainz-ngs/issues/85
    def decode(self, codec='utf-8'):
        """Decode, in place, every textual attribute that currently
        holds bytes, then ask each object in ``tracks`` to do the same.

        ``codec`` names the encoding to decode with; undecodable bytes
        are dropped rather than raising.
        """
        text_fields = (
            'album', 'artist', 'albumtype', 'label', 'artist_sort',
            'catalognum', 'script', 'language', 'country',
            'albumstatus', 'albumdisambig', 'releasegroupdisambig',
            'artist_credit', 'media',
        )
        for name in text_fields:
            raw = getattr(self, name)
            if not isinstance(raw, bytes):
                continue
            setattr(self, name, raw.decode(codec, 'ignore'))

        # A missing or empty track list means there is nothing more to do.
        for track in self.tracks or ():
            track.decode(codec)
134
135
class TrackInfo(object):
    """A canonical track on a release, as found in the ``tracks`` list
    of an AlbumInfo. Every constructor argument is stored unchanged on
    an attribute of the same name:

    - ``title``: name of the track
    - ``track_id``: MusicBrainz ID; UUID fragment only
    - ``release_track_id``: MusicBrainz ID respective to a track on a
            particular release; UUID fragment only
    - ``artist``: individual track artist name
    - ``artist_id``
    - ``length``: float: duration of the track in seconds
    - ``index``: position on the entire release
    - ``media``: delivery mechanism (Vinyl, etc.)
    - ``medium``: the disc number this track appears on in the album
    - ``medium_index``: the track's position on the disc
    - ``medium_total``: the number of tracks on the item's disc
    - ``artist_sort``: name of the track artist for sorting
    - ``disctitle``: name of the individual medium (subtitle)
    - ``artist_credit``: Recording-specific artist name
    - ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
    - ``data_url``: The data source release URL.
    - ``lyricist``: individual track lyricist name
    - ``composer``: individual track composer name
    - ``composer_sort``: individual track composer sort name
    - ``arranger``: individual track arranger name
    - ``track_alt``: alternative track number (tape, vinyl, etc.)

    Only ``title`` and ``track_id`` are required. The rest of the fields
    may be None. The indices ``index``, ``medium``, and ``medium_index``
    are all 1-based.
    """
    def __init__(self, title, track_id, release_track_id=None, artist=None,
                 artist_id=None, length=None, index=None, medium=None,
                 medium_index=None, medium_total=None, artist_sort=None,
                 disctitle=None, artist_credit=None, data_source=None,
                 data_url=None, media=None, lyricist=None, composer=None,
                 composer_sort=None, arranger=None, track_alt=None):
        # Identity of the track.
        self.title = title
        self.track_id = track_id
        self.release_track_id = release_track_id
        self.artist = artist
        self.artist_id = artist_id

        # Duration and position on the release / medium.
        self.length = length
        self.index = index
        self.media = media
        self.medium = medium
        self.medium_index = medium_index
        self.medium_total = medium_total

        # Remaining descriptive metadata and provenance.
        self.artist_sort = artist_sort
        self.disctitle = disctitle
        self.artist_credit = artist_credit
        self.data_source = data_source
        self.data_url = data_url
        self.lyricist = lyricist
        self.composer = composer
        self.composer_sort = composer_sort
        self.arranger = arranger
        self.track_alt = track_alt

    # Same python-musicbrainz-ngs bytes-vs-Unicode workaround as used
    # for album-level metadata.
    def decode(self, codec='utf-8'):
        """Decode, in place, each of the listed attributes that
        currently holds bytes. ``codec`` names the encoding; bytes that
        cannot be decoded are dropped rather than raising.
        """
        text_fields = (
            'title', 'artist', 'medium', 'artist_sort', 'disctitle',
            'artist_credit', 'media',
        )
        for name in text_fields:
            raw = getattr(self, name)
            if not isinstance(raw, bytes):
                continue
            setattr(self, name, raw.decode(codec, 'ignore'))
205
206
207# Candidate distance scoring.
208
# Tuning tables for string_dist().

# Articles that may be written at the end of a title after a comma
# ("something, the") without that counting as a difference.
SD_END_WORDS = [
    'the',
    'a',
    'an',
]

# (regex, weight) pairs. Text matched by a regex only contributes
# `weight` times the distance it would otherwise add.
SD_PATTERNS = [
    # A leading "the ".
    (r'^the ', 0.1),
    # "ep" or "single", optionally wrapped in brackets or parentheses.
    (r'[\[\(]?(ep|single)[\]\)]?', 0.0),
    # A "featuring ..." credit and everything after it.
    (r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
    # Parenthesized text.
    (r'\(.*?\)', 0.3),
    # Bracketed text.
    (r'\[.*?\]', 0.3),
    # A trailing "pt. N" / "part N" designation.
    (r'(, )?(pt\.|part) .+', 0.2),
]

# (regex, replacement) substitutions applied to both strings before
# any distance is measured.
SD_REPLACE = [
    (r'&', 'and'),
]
225
226
def _string_dist_basic(str1, str2):
    """Return the Levenshtein distance between two Unicode strings
    after normalization, divided by the length of the longer
    normalized string.

    Normalization transliterates each string with ``unidecode``,
    lowercases it, and removes every character other than ``a-z`` and
    ``0-9``. Two strings that normalize to nothing are at distance 0.0.
    """
    assert isinstance(str1, six.text_type)
    assert isinstance(str2, six.text_type)

    def _normalize(text):
        # Transliterate, lowercase, and keep only letters and digits.
        ascii_text = as_string(unidecode(text))
        return re.sub(r'[^a-z0-9]', '', ascii_text.lower())

    norm1 = _normalize(str1)
    norm2 = _normalize(str2)

    longest = max(len(norm1), len(norm2))
    if not longest:
        # Both strings are empty after normalization.
        return 0.0
    return levenshtein_distance(norm1, norm2) / float(longest)
242
243
def string_dist(str1, str2):
    """Return an "intuitive" edit distance between two strings: a
    length-normalized edit distance adjusted by a few text heuristics.

    Two None values are at distance 0.0; a None paired with a string is
    at distance 1.0. Otherwise both strings are lowercased, a trailing
    ", the" / ", a" / ", an" is moved to the front, the SD_REPLACE
    substitutions are applied, and the parts matched by SD_PATTERNS
    are counted at their reduced weights.
    """
    if str1 is None or str2 is None:
        # Both missing: identical. Exactly one missing: total mismatch.
        return 0.0 if str1 is str2 else 1.0

    left = str1.lower()
    right = str2.lower()

    # "something, the" should compare equal to "the something", so
    # move a trailing article back to the front.
    for word in SD_END_WORDS:
        suffix = ', %s' % word
        if left.endswith(suffix):
            left = '%s %s' % (word, left[:-len(suffix)])
        if right.endswith(suffix):
            right = '%s %s' % (word, right[:-len(suffix)])

    # Basic normalizing substitutions.
    for pat, repl in SD_REPLACE:
        left = re.sub(pat, repl, left)
        right = re.sub(pat, repl, right)

    # For each pattern in turn, delete its matches and see whether the
    # strings get closer. If they do, keep the shortened strings as the
    # new baseline (so the same text is not matched again) and charge
    # only a weighted fraction of the distance that was removed.
    current = _string_dist_basic(left, right)
    penalty = 0.0
    for pat, weight in SD_PATTERNS:
        stripped_left = re.sub(pat, '', left)
        stripped_right = re.sub(pat, '', right)
        if stripped_left == left and stripped_right == right:
            # The pattern occurs in neither string.
            continue

        stripped_dist = _string_dist_basic(stripped_left, stripped_right)
        improvement = max(0.0, current - stripped_dist)
        if improvement == 0.0:
            # Removing the match did not bring the strings closer.
            continue

        left, right = stripped_left, stripped_right
        current = stripped_dist
        penalty += weight * improvement

    return current + penalty
300
301
class LazyClassProperty(object):
    """Decorator for a read-only, class-level property whose getter
    runs at most once. The getter receives the owning class; its result
    is stored on the descriptor and returned for every later access,
    whichever class or instance the access goes through.
    """
    def __init__(self, getter):
        self.getter = getter
        # Becomes True once `value` has been filled in.
        self.computed = False

    def __get__(self, obj, owner):
        if self.computed:
            return self.value

        # First access: evaluate the getter against the owner class and
        # remember the outcome.
        self.value = self.getter(owner)
        self.computed = True
        return self.value
316
317
@total_ordering
@six.python_2_unicode_compatible
class Distance(object):
    """Accumulates named distance penalties for a candidate match.

    Each key maps to a list of penalty values between 0.0 and 1.0. The
    object provides one weighted, normalized distance across all
    penalties and, through dict-style access, the weighted share
    contributed by each key. It also compares, subtracts and formats
    like the float value of that overall distance.
    """
    def __init__(self):
        # Penalty key -> list of the values recorded under that key.
        self._penalties = {}

    @LazyClassProperty
    def _weights(cls):  # noqa
        """A dictionary from keys to floating-point weights, read from
        the ``match.distance_weights`` configuration view.
        """
        view = config['match']['distance_weights']
        return {name: view[name].as_number() for name in view.keys()}

    # Access the components and their aggregates.

    @property
    def distance(self):
        """The weighted distance over all penalties, normalized by
        `max_distance`; 0.0 when `max_distance` is zero.
        """
        ceiling = self.max_distance
        if not ceiling:
            return 0.0
        return self.raw_distance / ceiling

    @property
    def max_distance(self):
        """The normalization factor: for every key, the number of
        recorded penalties times the key's weight, summed.
        """
        total = 0.0
        for key, values in self._penalties.items():
            total += len(values) * self._weights[key]
        return total

    @property
    def raw_distance(self):
        """The denormalized distance: for every key, the sum of the
        recorded penalties times the key's weight, summed.
        """
        total = 0.0
        for key, values in self._penalties.items():
            total += sum(values) * self._weights[key]
        return total

    def items(self):
        """Return a list of (key, dist) pairs, where `dist` is the
        weighted distance for that key, ordered from the largest
        distance to the smallest. Keys whose distance is zero are left
        out.
        """
        scored = ((key, self[key]) for key in self._penalties)
        nonzero = [pair for pair in scored if pair[1]]
        # Sorting on the negated distance puts the biggest distances
        # first while still ordering equal distances by ascending key.
        nonzero.sort(key=lambda pair: (-pair[1], pair[0]))
        return nonzero

    def __hash__(self):
        # Hash by identity; equality, by contrast, compares distances.
        return id(self)

    def __eq__(self, other):
        return self.distance == other

    # Behave like a float.

    def __lt__(self, other):
        return self.distance < other

    def __float__(self):
        return self.distance

    def __sub__(self, other):
        return self.distance - other

    def __rsub__(self, other):
        return other - self.distance

    def __str__(self):
        return "{0:.2f}".format(self.distance)

    # Behave like a dict.

    def __getitem__(self, key):
        """Return the weighted distance for the named penalty,
        normalized by `max_distance` (0.0 when that is zero).
        """
        weighted = sum(self._penalties[key]) * self._weights[key]
        ceiling = self.max_distance
        return weighted / ceiling if ceiling else 0.0

    def __iter__(self):
        return iter(self.items())

    def __len__(self):
        return len(self.items())

    def keys(self):
        return [pair[0] for pair in self.items()]

    def update(self, dist):
        """Append every penalty recorded on `dist`, another Distance,
        to this object. Raises ValueError for any other type.
        """
        if not isinstance(dist, Distance):
            raise ValueError(
                u'`dist` must be a Distance object, not {0}'.format(type(dist))
            )
        for key, values in dist._penalties.items():
            self._penalties.setdefault(key, []).extend(values)

    # Adding components.

    def _eq(self, value1, value2):
        """Return True if `value1` equals `value2`. When `value1` is a
        compiled regular expression it is matched against `value2`
        instead of being compared.
        """
        if isinstance(value1, Pattern):
            return bool(value1.match(value2))
        return value1 == value2

    def add(self, key, dist):
        """Record a distance penalty under `key`, which must correspond
        with a configured weight setting. `dist` must be a float
        between 0.0 and 1.0 (ValueError otherwise); it is appended to
        whatever penalties `key` already has.
        """
        if not 0.0 <= dist <= 1.0:
            raise ValueError(
                u'`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
            )
        self._penalties.setdefault(key, []).append(dist)

    def add_equality(self, key, value, options):
        """Record 0.0 if `value` matches any entry of `options`, else
        1.0. `options` may be a single option or a list/tuple of them;
        a compiled regular expression matches when it matches against
        `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        matched = any(self._eq(opt, value) for opt in options)
        self.add(key, 0.0 if matched else 1.0)

    def add_expr(self, key, expr):
        """Record 1.0 if `expr` is truthy, otherwise 0.0."""
        self.add(key, 1.0 if expr else 0.0)

    def add_number(self, key, number1, number2):
        """Record one penalty of 1.0 per unit of difference between
        `number1` and `number2`, or a single 0.0 when they are equal.
        Use this when the difference has no upper bound.
        """
        diff = abs(number1 - number2)
        if not diff:
            self.add(key, 0.0)
            return
        for _ in range(diff):
            self.add(key, 1.0)

    def add_priority(self, key, value, options):
        """Record a penalty that grows with the position at which
        `value` first matches in `options`: 0.0 for the first option,
        1.0 when nothing matches. `options` may be a single option or a
        list/tuple; a compiled regular expression matches when it
        matches against `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        # Each step down the list costs an equal share of 1.0.
        unit = 1.0 / (len(options) or 1)
        dist = next(
            (i * unit for i, opt in enumerate(options)
             if self._eq(opt, value)),
            1.0
        )
        self.add(key, dist)

    def add_ratio(self, key, number1, number2):
        """Record `number1` as a fraction of `number2`, with `number1`
        first bounded to lie between 0 and `number2`. Records 0.0 when
        `number2` is zero.
        """
        clamped = float(max(min(number1, number2), 0))
        self.add(key, clamped / number2 if number2 else 0.0)

    def add_string(self, key, str1, str2):
        """Record the `string_dist` edit distance between `str1` and
        `str2`.
        """
        self.add(key, string_dist(str1, str2))
535
536
537# Structures that compose all the information for a candidate match.
538
# An album-level candidate match.
AlbumMatch = namedtuple(
    'AlbumMatch',
    [
        'distance',
        'info',
        'mapping',
        'extra_items',
        'extra_tracks',
    ]
)

# A single-track candidate match.
TrackMatch = namedtuple(
    'TrackMatch',
    [
        'distance',
        'info',
    ]
)
543
544
545# Aggregation of sources.
546
def album_for_mbid(release_id):
    """Look up a MusicBrainz release ID and return the resulting
    AlbumInfo object, announcing it to plugins through the
    ``albuminfo_received`` event. Return None if the ID is not found.
    A MusicBrainzAPIError is logged and also yields None.
    """
    try:
        info = mb.album_for_id(release_id)
        if info:
            plugins.send(u'albuminfo_received', info=info)
    except mb.MusicBrainzAPIError as exc:
        exc.log(log)
        return None
    return info
558
559
def track_for_mbid(recording_id):
    """Look up a MusicBrainz recording ID and return the resulting
    TrackInfo object, announcing it to plugins through the
    ``trackinfo_received`` event. Return None if the ID is not found.
    A MusicBrainzAPIError is logged and also yields None.
    """
    try:
        info = mb.track_for_id(recording_id)
        if info:
            plugins.send(u'trackinfo_received', info=info)
    except mb.MusicBrainzAPIError as exc:
        exc.log(log)
        return None
    return info
571
572
def albums_for_id(album_id):
    """Generate the albums found for an ID: first the MusicBrainz
    result, if any, then each non-empty result from the plugins.
    Plugin results are announced through ``albuminfo_received`` before
    they are yielded.
    """
    mb_album = album_for_mbid(album_id)
    if mb_album:
        yield mb_album

    for plugin_album in plugins.album_for_id(album_id):
        if not plugin_album:
            continue
        plugins.send(u'albuminfo_received', info=plugin_album)
        yield plugin_album
582
583
def tracks_for_id(track_id):
    """Generate the tracks found for an ID: first the MusicBrainz
    result, if any, then each non-empty result from the plugins.
    Plugin results are announced through ``trackinfo_received`` before
    they are yielded.
    """
    mb_track = track_for_mbid(track_id)
    if mb_track:
        yield mb_track

    for plugin_track in plugins.track_for_id(track_id):
        if not plugin_track:
            continue
        plugins.send(u'trackinfo_received', info=plugin_track)
        yield plugin_track
593
594
@plugins.notify_info_yielded(u'albuminfo_received')
def album_candidates(items, artist, album, va_likely):
    """Generate candidate matches for an album. ``items`` is the list
    of Item objects that make up the album. ``artist`` and ``album``
    are the names (strings) to search for, either derived from the
    items or typed in by the user. ``va_likely`` is a boolean saying
    whether the album is likely to be a "various artists" release.
    """
    # Artist arguments for the MusicBrainz queries to run, in order:
    # the named artist when both names are known, then None for the
    # "various artists" search where appropriate.
    mb_artists = []
    if artist and album:
        mb_artists.append(artist)
    if va_likely and album:
        mb_artists.append(None)

    for mb_artist in mb_artists:
        try:
            for candidate in mb.match_album(mb_artist, album, len(items)):
                yield candidate
        except mb.MusicBrainzAPIError as exc:
            # A failed query is logged; the remaining sources still run.
            exc.log(log)

    # Candidates from plugins.
    for candidate in plugins.candidates(items, artist, album, va_likely):
        yield candidate
622
623
@plugins.notify_info_yielded(u'trackinfo_received')
def item_candidates(item, artist, title):
    """Generate candidate matches for a single item. ``item`` is the
    Item being matched; ``artist`` and ``title`` are the strings to
    search for, taken from the item or supplied by the user.
    """
    # MusicBrainz is queried only when both search terms are present.
    if artist and title:
        try:
            for mb_match in mb.match_track(artist, title):
                yield mb_match
        except mb.MusicBrainzAPIError as exc:
            # A failed query is logged; plugin candidates still follow.
            exc.log(log)

    # Plugin candidates.
    for plugin_match in plugins.item_candidates(item, artist, title):
        yield plugin_match
642