1# -*- coding: utf-8 -*-
2# ------------------------------------------------------------------------------
3# Name:         features/base.py
4# Purpose:      Feature extractors base classes.
5#
6# Authors:      Christopher Ariza
7#               Michael Scott Cuthbert
8#
9# Copyright:    Copyright © 2011-2017 Michael Scott Cuthbert and the music21 Project
10# License:      BSD, see license.txt
11# ------------------------------------------------------------------------------
12import os
13import pathlib
14import pickle
15import unittest
16
17from collections import Counter
18
19from music21 import common
20from music21 import converter
21from music21 import corpus
22from music21 import exceptions21
23from music21 import note
24from music21 import stream
25from music21 import text
26
27from music21.metadata.bundles import MetadataEntry
28
29from music21 import environment
# module identifier used to scope environment/debug messages for this file
_MOD = 'features.base'
environLocal = environment.Environment(_MOD)
32
33# ------------------------------------------------------------------------------
34
35
class FeatureException(exceptions21.Music21Exception):
    '''
    Exception raised by Feature objects, e.g. when attempting to
    normalize an all-zero vector.
    '''
    pass
38
39
class Feature:
    '''
    An object representation of a feature, capable of presentation in a variety of formats,
    and returned from FeatureExtractor objects.

    Feature objects are simple. It is FeatureExtractors that store all metadata and processing
    routines for creating Feature objects.  Normally you wouldn't create one of these yourself.

    >>> myFeature = features.Feature()
    >>> myFeature.dimensions = 3
    >>> myFeature.name = 'Random arguments'
    >>> myFeature.isSequential = True

    This is a continuous Feature so we will set discrete to false.

    >>> myFeature.discrete = False

    The .vector is the most important part of the feature, and it starts out as None.

    >>> myFeature.vector is None
    True

    Calling .prepareVectors() gives it a list of Zeros of the length of dimensions.

    >>> myFeature.prepareVectors()

    >>> myFeature.vector
    [0, 0, 0]

    Now we can set the vector parts:

    >>> myFeature.vector[0] = 4
    >>> myFeature.vector[1] = 2
    >>> myFeature.vector[2] = 1

    It's okay just to assign a new list to .vector itself.

    There is a normalize() method which normalizes the values
    of a histogram to sum to 1.

    >>> myFeature.normalize()
    >>> myFeature.vector
    [0.571..., 0.285..., 0.142...]

    And that's it! FeatureExtractors are much more interesting.
    '''

    def __init__(self):
        # these values will be filled by the extractor
        self.dimensions = None  # number of dimensions
        # data storage; possibly use numpy array
        self.vector = None

        # consider not storing these values, as they may not be necessary
        self.name = None  # string name representation
        self.description = None  # string description
        self.isSequential = None  # True or False
        self.discrete = None  # is discrete or continuous

    def _getVectors(self):
        '''
        Prepare a vector of appropriate size and return
        '''
        return [0] * self.dimensions

    def prepareVectors(self):
        '''
        Prepare the vector stored in this feature.
        '''
        self.vector = self._getVectors()

    def normalize(self):
        '''
        Normalizes the vector so that the sum of its elements is 1.

        Raises FeatureException on an all-zero vector, which cannot be
        normalized.
        '''
        s = sum(self.vector)
        try:
            scalar = 1.0 / s  # get floating point scalar for speed
        except ZeroDivisionError as zde:
            # chain the original error so debuggers can see the root cause
            raise FeatureException('cannot normalize zero vector') from zde
        # scale every element in a single pass
        self.vector = [v * scalar for v in self.vector]
125
126# ------------------------------------------------------------------------------
class FeatureExtractorException(exceptions21.Music21Exception):
    '''
    Exception raised when a FeatureExtractor cannot process its data.
    '''
    pass
129
130
class FeatureExtractor:
    '''
    A model of process that extracts a feature from a Music21 Stream.
    The main public interface is the extract() method.

    The extractor can be passed a Stream or a reference to a DataInstance.
    All Streams are internally converted to a DataInstance if necessary.
    Usage of a DataInstance offers significant performance advantages, as common forms of
    the Stream are cached for easy processing.
    '''

    def __init__(self, dataOrStream=None, *arguments, **keywords):
        self.stream = None  # the original Stream, or None
        self.data = None  # a DataInstance object: use to get data
        self.setData(dataOrStream)

        self.feature = None  # Feature object that results from processing

        # subclasses typically define these as class-level attributes;
        # only supply a default for the ones still missing
        for attrName, default in (
            ('name', None),  # string name representation
            ('description', None),  # string description
            ('isSequential', None),  # True or False
            ('dimensions', None),  # number of dimensions
            ('discrete', True),  # default
            ('normalize', False),  # default is no
        ):
            if not hasattr(self, attrName):
                setattr(self, attrName, default)

    def setData(self, dataOrStream):
        '''
        Set the data that this FeatureExtractor will process.
        Either a Stream or a DataInstance object can be provided.
        '''
        if dataOrStream is None:
            return
        if hasattr(dataOrStream, 'classes') and isinstance(dataOrStream, stream.Stream):
            # a raw Stream gets wrapped in a new DataInstance; this is less
            # efficient than passing a DataInstance but is good for testing
            self.stream = dataOrStream
            self.data = DataInstance(self.stream)
        else:
            # anything else is assumed to be a DataInstance already
            self.stream = None
            self.data = dataOrStream

    def getAttributeLabels(self):
        '''Return a list of string in a form that is appropriate for data storage.


        >>> fe = features.jSymbolic.AmountOfArpeggiationFeature()
        >>> fe.getAttributeLabels()
        ['Amount_of_Arpeggiation']

        >>> fe = features.jSymbolic.FifthsPitchHistogramFeature()
        >>> fe.getAttributeLabels()
        ['Fifths_Pitch_Histogram_0', 'Fifths_Pitch_Histogram_1', 'Fifths_Pitch_Histogram_2',
         'Fifths_Pitch_Histogram_3', 'Fifths_Pitch_Histogram_4', 'Fifths_Pitch_Histogram_5',
         'Fifths_Pitch_Histogram_6', 'Fifths_Pitch_Histogram_7', 'Fifths_Pitch_Histogram_8',
         'Fifths_Pitch_Histogram_9', 'Fifths_Pitch_Histogram_10', 'Fifths_Pitch_Histogram_11']

        '''
        base = self.name.replace(' ', '_')
        if self.dimensions == 1:
            return [base]
        # multidimensional features get one numbered label per dimension
        return [f'{base}_{i}' for i in range(self.dimensions)]

    def fillFeatureAttributes(self, feature=None):
        '''Fill the attributes of a Feature with the descriptors in the FeatureExtractor.
        '''
        # operate on passed-in feature or self.feature
        if feature is None:
            feature = self.feature
        for attrName in ('name', 'description', 'isSequential', 'dimensions', 'discrete'):
            setattr(feature, attrName, getattr(self, attrName))
        return feature

    def prepareFeature(self):
        '''
        Prepare a new Feature object for data acquisition.

        >>> s = stream.Stream()
        >>> fe = features.jSymbolic.InitialTimeSignatureFeature(s)
        >>> fe.prepareFeature()
        >>> fe.feature.name
        'Initial Time Signature'
        >>> fe.feature.dimensions
        2
        >>> fe.feature.vector
        [0, 0]
        '''
        self.feature = Feature()
        self.fillFeatureAttributes()  # fills self.feature in place
        self.feature.prepareVectors()  # zeroes the vector

    def process(self):
        '''Do processing necessary, storing result in _feature.
        '''
        # subclasses override this, reading from self.data
        pass

    def extract(self, source=None):
        '''Extract the feature and return the result.
        '''
        if source is not None:
            self.stream = source
        # preparing the feature always sets self.feature to a new instance
        self.prepareFeature()
        self.process()  # will set Feature object to _feature
        if self.normalize:
            self.feature.normalize()
        return self.feature

    def getBlankFeature(self):
        '''
        Return a properly configured plain feature as a place holder

        >>> from music21 import features
        >>> fe = features.jSymbolic.InitialTimeSignatureFeature()
        >>> fe.name
        'Initial Time Signature'

        >>> blankF = fe.getBlankFeature()
        >>> blankF.vector
        [0, 0]
        >>> blankF.name
        'Initial Time Signature'
        '''
        blank = Feature()
        self.fillFeatureAttributes(blank)
        blank.prepareVectors()  # zeroes the vector
        return blank
274
275
276# ------------------------------------------------------------------------------
class StreamForms:
    '''
    A dictionary-like wrapper of a Stream, providing
    numerous representations, generated on-demand, and cached.

    A single StreamForms object can be created for an
    entire Score, as well as one for each Part and/or Voice.

    A DataSet object manages one or more StreamForms
    objects, and exposes them to FeatureExtractors for usage.

    The streamObj is stored as self.stream and if "prepared" then
    the prepared form is stored as .prepared

    A dictionary `.forms` stores various intermediary representations
    of the stream which is the main power of this routine, making
    it simple to add additional feature extractors at low additional
    time cost.

    '''

    def __init__(self, streamObj, prepareStream=True):
        self.stream = streamObj
        if self.stream is not None:
            if prepareStream:
                self.prepared = self._prepareStream(self.stream)
            else:
                # caller guarantees the Stream needs no preparation
                self.prepared = self.stream
        else:
            self.prepared = None

        # basic data storage is a dictionary, keyed by dotted form names
        self.forms = {}

    def keys(self):
        # will only return forms that are established
        return self.forms.keys()

    def _prepareStream(self, streamObj):
        '''
        Common routines done on Streams prior to processing. Returns a new Stream

        Currently: runs stripTies.
        '''
        # Let stripTies make a copy so that we don't leave side effects on the input stream
        streamObj = streamObj.stripTies(inPlace=False)
        return streamObj

    def __getitem__(self, key):
        '''
        Get a form of this Stream, using a cached version if available.

        Dotted keys such as 'flat.pitches' are resolved left to right and
        every intermediate form is cached along the way.
        '''
        # first, check for cached version
        if key in self.forms:
            return self.forms[key]

        splitKeys = key.split('.')

        prepared = self.prepared
        for i in range(len(splitKeys)):
            subKey = '.'.join(splitKeys[:i + 1])
            if subKey in self.forms:
                continue
            if i > 0:
                previousKey = '.'.join(splitKeys[:i])
                # should always be there.
                prepared = self.forms[previousKey]

            lastKey = splitKeys[i]

            if lastKey in self.keysToMethods:
                prepared = self.keysToMethods[lastKey](self, prepared)
            elif lastKey.startswith('getElementsByClass('):
                # extract the class name between the parentheses
                classToGet = lastKey[len('getElementsByClass('):-1]
                prepared = prepared.getElementsByClass(classToGet)
            else:
                raise AttributeError(f'no such attribute: {lastKey} in {key}')
            self.forms[subKey] = prepared

        return prepared

    def _getIntervalHistogram(self, algorithm='midi'):
        '''
        Histogram (128 bins) of absolute melodic interval sizes, computed
        per part, using the given pitch attribute (e.g. 'midi').
        '''
        # note that this does not optimize and cache part presentations
        histo = [0] * 128
        # if we have parts, must add one at a time
        if self.prepared.hasPartLikeStreams():
            parts = self.prepared.parts
        else:
            parts = [self.prepared]  # emulate a list
        for p in parts:
            # will be flat

            # noNone means that we will see all connections, even w/ a gap
            post = p.findConsecutiveNotes(skipRests=True,
                                          skipChords=True, skipGaps=True, noNone=True)
            for i, n in enumerate(post):
                if i < len(post) - 1:  # if not last
                    nNext = post[i + 1]
                    try:
                        # the getattr calls are inside the try so that a
                        # pitch lacking the attribute (e.g. no midi) is
                        # skipped, as the except was always intended to do
                        nValue = getattr(n.pitch, algorithm)
                        nextValue = getattr(nNext.pitch, algorithm)
                        histo[abs(nValue - nextValue)] += 1
                    except AttributeError:
                        pass  # problem with not having midi
        return histo

    def formPartitionByInstrument(self, prepared):
        # group notes into per-instrument substreams
        from music21 import instrument
        return instrument.partitionByInstrument(prepared)

    def formSetClassHistogram(self, prepared):
        # counts of Forte set-class names among the chords
        return Counter([c.forteClassTnI for c in prepared])

    def formPitchClassSetHistogram(self, prepared):
        # counts of ordered pitch-class-set strings among the chords
        return Counter([c.orderedPitchClassesString for c in prepared])

    def formTypesHistogram(self, prepared):
        '''
        Count, per chord-quality predicate, how many chords satisfy it.
        '''
        histo = {}

        # keys are methods on Chord
        keys = ['isTriad', 'isSeventh', 'isMajorTriad', 'isMinorTriad',
                'isIncompleteMajorTriad', 'isIncompleteMinorTriad', 'isDiminishedTriad',
                'isAugmentedTriad', 'isDominantSeventh', 'isDiminishedSeventh',
                'isHalfDiminishedSeventh']

        for c in prepared:
            for thisKey in keys:
                if thisKey not in histo:
                    histo[thisKey] = 0
                # get the function attr, call it, check bool
                if getattr(c, thisKey)():
                    histo[thisKey] += 1
        return histo

    def formGetElementsByClassMeasure(self, prepared):
        # for a Score, gather measures from all parts into a single Stream
        if isinstance(prepared, stream.Score):
            post = stream.Stream()
            for p in prepared.parts:
                # insert in overlapping offset positions
                for m in p.getElementsByClass('Measure'):
                    post.insert(m.getOffsetBySite(p), m)
        else:
            post = prepared.getElementsByClass('Measure')
        return post

    def formChordify(self, prepared):
        if isinstance(prepared, stream.Score):
            # options here permit getting part information out
            # of chordified representation
            return prepared.chordify(
                addPartIdAsGroup=True, removeRedundantPitches=False)
        else:  # for now, just return a normal Part or Stream
            # this seems wrong -- what if there are multiple voices
            # in the part?
            return prepared

    def formQuarterLengthHistogram(self, prepared):
        # counts of note durations (in quarter lengths, as floats)
        return Counter([float(n.quarterLength) for n in prepared])

    def formMidiPitchHistogram(self, pitches):
        # counts of MIDI pitch numbers
        return Counter([p.midi for p in pitches])

    def formPitchClassHistogram(self, pitches):
        # 12-element list of pitch-class counts (index = pitch class)
        cc = Counter([p.pitchClass for p in pitches])
        histo = [0] * 12
        for k in cc:
            histo[k] = cc[k]
        return histo

    def formMidiIntervalHistogram(self, unused):
        return self._getIntervalHistogram('midi')

    def formContourList(self, prepared):
        '''
        List of all directed half-step intervals between consecutive
        notes (per part); for chords the highest pitch is used.
        '''
        # list of all directed half steps
        cList = []
        # if we have parts, must add one at a time
        if prepared.hasPartLikeStreams():
            parts = prepared.parts
        else:
            parts = [prepared]  # emulate a list

        for p in parts:
            # this may be unnecessary but we cannot accessed cached part data

            # noNone means that we will see all connections, even w/ a gap
            post = p.findConsecutiveNotes(skipRests=True,
                                          skipChords=False,
                                          skipGaps=True,
                                          noNone=True)
            for i, n in enumerate(post):
                if i < (len(post) - 1):  # if not last
                    iNext = i + 1
                    nNext = post[iNext]

                    if n.isChord:
                        ps = n.sortDiatonicAscending().pitches[-1].midi
                    else:  # normal note
                        ps = n.pitch.midi
                    if nNext.isChord:
                        psNext = nNext.sortDiatonicAscending().pitches[-1].midi
                    else:  # normal note
                        psNext = nNext.pitch.midi

                    cList.append(psNext - ps)
        # environLocal.printDebug(['contourList', cList])
        return cList

    def formSecondsMap(self, prepared):
        # keep only the secondsMap entries whose element is a NotRest
        post = []
        secondsMap = prepared.secondsMap
        # filter only notes; all elements would otherwise be gathered
        for bundle in secondsMap:
            if isinstance(bundle['element'], note.NotRest):
                post.append(bundle)
        return post

    def formBeatHistogram(self, secondsMap):
        '''
        Histogram (200 bins, indexed by integer BPM) of beat rates
        derived from note durations; only BPM values in [40, 199] are
        counted.
        '''
        secondsList = [d['durationSeconds'] for d in secondsMap]
        bpmList = [round(60.0 / d) for d in secondsList]
        histogram = [0] * 200
        for thisBPM in bpmList:
            # the upper bound is tied to the histogram length: the previous
            # `thisBPM > 200` test admitted a BPM of exactly 200, which
            # raised IndexError on a 200-element list
            if thisBPM < 40 or thisBPM >= len(histogram):
                continue
            histogramIndex = int(thisBPM)
            histogram[histogramIndex] += 1
        return histogram

    # maps the last component of a dotted form key to the callable
    # that derives that form from the previous one
    keysToMethods = {
        'flat': lambda unused, p: p.flatten(),
        'pitches': lambda unused, p: p.pitches,
        'notes': lambda unused, p: p.notes,
        'getElementsByClass(Measure)': formGetElementsByClassMeasure,
        'metronomeMarkBoundaries': lambda unused, p: p.metronomeMarkBoundaries(),
        'chordify': formChordify,
        'partitionByInstrument': formPartitionByInstrument,
        'setClassHistogram': formSetClassHistogram,
        'pitchClassHistogram': formPitchClassHistogram,
        'typesHistogram': formTypesHistogram,
        'quarterLengthHistogram': formQuarterLengthHistogram,
        'pitchClassSetHistogram': formPitchClassSetHistogram,
        'midiPitchHistogram': formMidiPitchHistogram,
        'midiIntervalHistogram': formMidiIntervalHistogram,
        'contourList': formContourList,
        'analyzedKey': lambda unused, f: f.analyze(method='key'),
        'tonalCertainty': lambda unused, foundKey: foundKey.tonalCertainty(),
        'metadata': lambda unused, p: p.metadata,
        'secondsMap': formSecondsMap,
        'assembledLyrics': lambda unused, p: text.assembleLyrics(p),
        'beatHistogram': formBeatHistogram,
    }
530
531# ------------------------------------------------------------------------------
532
533
class DataInstance:
    '''
    A data instance for analysis. This object prepares a Stream
    (by stripping ties, etc.) and stores
    multiple commonly-used stream representations once, providing rapid processing.
    '''
    # pylint: disable=redefined-builtin
    def __init__(self, streamOrPath=None, id=None):  # @ReservedAssignment
        if isinstance(streamOrPath, stream.Stream):
            self.stream = streamOrPath
            self.streamPath = None
        else:
            self.stream = None
            self.streamPath = streamOrPath

        # identifier for the source stream: file path url, corpus url
        # or metadata title
        self._id = self._deriveIdentifier(id)

        # the attribute name in the data set for this label
        self.classLabel = None
        # store the class value for this data instance
        self._classValue = None

        self.forms = None

        # store a list of voices, extracted from each part,
        self.formsByVoice = []
        # if parts exist, store a forms for each
        self.formsByPart = []

        self.featureExtractorClassesForParallelRunning = []

        if self.stream is not None:
            self.setupPostStreamParse()

    def _deriveIdentifier(self, id):  # pylint: disable=redefined-builtin
        '''
        Pick an identifier, in precedence order: the explicit id, the
        Stream's metadata title, the Stream's sourcePath, the stored
        path's sourcePath, the stored path itself, or ''.
        '''
        if id is not None:
            return id
        if self.stream is not None:
            md = getattr(self.stream, 'metadata', None)
            if md is not None and md.title is not None:
                return md.title
            if hasattr(self.stream, 'sourcePath'):
                return self.stream.sourcePath
        elif self.streamPath is not None:
            if hasattr(self.streamPath, 'sourcePath'):
                return str(self.streamPath.sourcePath)
            return str(self.streamPath)
        return ''

    def setupPostStreamParse(self):
        '''
        Setup the StreamForms objects and other things that
        need to be done after a Stream is passed in but before
        feature extracting is run.

        Run automatically at instantiation if a Stream is passed in.
        '''
        # one StreamForms wrapper for the whole score
        self.forms = StreamForms(self.stream)

        # one StreamForms per part, when parts exist
        self.formsByPart = []
        if hasattr(self.stream, 'parts'):
            self.partsCount = len(self.stream.parts)
            # note that this will join ties and expand rests again
            self.formsByPart = [StreamForms(p) for p in self.stream.parts]
        else:
            self.partsCount = 0

        # voices get their own StreamForms, appended after the parts
        for v in self.stream.recurse().getElementsByClass('Voice'):
            self.formsByPart.append(StreamForms(v))

    def setClassLabel(self, classLabel, classValue=None):
        '''
        Set the class label, as well as the class value if known.
        The class label is the attribute name used to define the class of this data instance.

        >>> #_DOCS_SHOW s = corpus.parse('bwv66.6')
        >>> s = stream.Stream() #_DOCS_HIDE
        >>> di = features.DataInstance(s)
        >>> di.setClassLabel('Composer', 'Bach')
        '''
        self.classLabel = classLabel
        self._classValue = classValue

    def getClassValue(self):
        '''
        Return the class value, resolving a callable value against the
        parsed Stream when possible; '' when nothing can be resolved.
        '''
        value = self._classValue
        if value is None:
            return ''
        if callable(value):
            if self.stream is None:
                return ''
            # resolve the callable once and cache the result
            self._classValue = value(self.stream)
        return self._classValue

    def getId(self):
        '''
        Return the id with spaces replaced by underscores, resolving a
        callable id against the parsed Stream when possible.
        '''
        identifier = self._id
        if identifier is None:
            return ''
        if callable(identifier):
            if self.stream is None:
                return ''
            # resolve the callable once and cache the result
            self._id = identifier(self.stream)

        # make sure there are no spaces
        try:
            return self._id.replace(' ', '_')
        except AttributeError as e:
            raise AttributeError(str(self._id)) from e

    def parseStream(self):
        '''
        If a path to a Stream has been passed in at creation,
        then this will parse it (whether it's a corpus string,
        a converter string (url or filepath), a pathlib.Path,
        or a metadata.bundles.MetadataEntry.
        '''
        if self.stream is not None:
            return

        source = self.streamPath
        if isinstance(source, str):
            # could be corpus or file path
            if os.path.exists(source) or source.startswith('http'):
                parsed = converter.parse(source)
            else:  # assume corpus
                parsed = corpus.parse(source)
        elif isinstance(source, pathlib.Path):
            # could be corpus or file path
            if source.exists():
                parsed = converter.parse(source)
            else:  # assume corpus
                parsed = corpus.parse(source)
        elif isinstance(source, MetadataEntry):
            parsed = source.parse()
        else:
            raise ValueError(f'Invalid streamPath type: {type(source)}')

        self.stream = parsed
        self.setupPostStreamParse()

    def __getitem__(self, key):
        '''
        Get a form of this Stream, using a cached version if available.

        >>> di = features.DataInstance('bach/bwv66.6')
        >>> len(di['flat'])
        193
        >>> len(di['flat.pitches'])
        163
        >>> len(di['flat.notes'])
        163
        >>> len(di['getElementsByClass(Measure)'])
        40
        >>> len(di['flat.getElementsByClass(TimeSignature)'])
        4
        '''
        self.parseStream()
        if key == 'parts':
            # a list of StreamForms, one per part
            return self.formsByPart
        if key == 'voices':
            # a list of StreamForms, one per voice
            return self.formsByVoice
        # delegate to StreamForms; raises AttributeError on unknown forms
        return self.forms[key]
704
705
706# ------------------------------------------------------------------------------
class DataSetException(exceptions21.Music21Exception):
    '''
    Exception raised by DataSet objects, e.g. for mismatched argument
    lengths when adding multiple data points.
    '''
    pass
709
710
711class DataSet:
712    '''
713    A set of features, as well as a collection of data to operate on.
714
715    Comprises multiple DataInstance objects, a FeatureSet, and an OutputFormat.
716
717
718    >>> ds = features.DataSet(classLabel='Composer')
719    >>> f = [features.jSymbolic.PitchClassDistributionFeature,
720    ...      features.jSymbolic.ChangesOfMeterFeature,
721    ...      features.jSymbolic.InitialTimeSignatureFeature]
722    >>> ds.addFeatureExtractors(f)
723    >>> ds.addData('bwv66.6', classValue='Bach')
724    >>> ds.addData('bach/bwv324.xml', classValue='Bach')
725    >>> ds.process()
726    >>> ds.getFeaturesAsList()[0]
727    ['bwv66.6', 0.196..., 0.0736..., 0.006..., 0.098..., 0.0368..., 0.177..., 0.0,
728     0.085..., 0.134..., 0.018..., 0.171..., 0.0, 0, 4, 4, 'Bach']
729    >>> ds.getFeaturesAsList()[1]
730    ['bach/bwv324.xml', 0.25, 0.0288..., 0.125, 0.0, 0.144..., 0.125, 0.0, 0.163..., 0.0, 0.134...,
731    0.0288..., 0.0, 0, 4, 4, 'Bach']
732
733    >>> ds = ds.getString()
734
735
736    By default, all exceptions are caught and printed if debug mode is on.
737
738    Set ds.failFast = True to not catch them.
739
740    Set ds.quiet = False to print them regardless of debug mode.
741    '''
742
    def __init__(self, classLabel=None, featureExtractors=()):
        '''
        Create a DataSet for the given class label (e.g. 'Composer'),
        optionally instantiating the given FeatureExtractor classes.
        '''
        # assume a two dimensional array
        self.dataInstances = []

        # order of feature extractors is the order used in the presentations
        self._featureExtractors = []
        self._instantiatedFeatureExtractors = []
        # the label of the class
        self._classLabel = classLabel
        # store a multidimensional storage of all features
        self.features = []

        # when True, exceptions during processing propagate instead of being caught
        self.failFast = False
        # when True, caught exceptions are not printed
        self.quiet = True

        self.runParallel = True
        # set extractors
        self.addFeatureExtractors(featureExtractors)
761
    def getClassLabel(self):
        '''
        Return the class label (the attribute name, such as 'Composer',
        under which class values are stored) given at construction.
        '''
        return self._classLabel
764
765    def addFeatureExtractors(self, values):
766        '''
767        Add one or more FeatureExtractor objects, either as a list or as an individual object.
768        '''
769        # features are instantiated here
770        # however, they do not have a data assignment
771        if not common.isIterable(values):
772            values = [values]
773        # need to create instances
774        for sub in values:
775            self._featureExtractors.append(sub)
776            self._instantiatedFeatureExtractors.append(sub())
777
778    def getAttributeLabels(self, includeClassLabel=True,
779                           includeId=True):
780        '''
781        Return a list of all attribute labels. Optionally add a class
782        label field and/or an id field.
783
784
785        >>> f = [features.jSymbolic.PitchClassDistributionFeature,
786        ...      features.jSymbolic.ChangesOfMeterFeature]
787        >>> ds = features.DataSet(classLabel='Composer', featureExtractors=f)
788        >>> ds.getAttributeLabels(includeId=False)
789        ['Pitch_Class_Distribution_0',
790         'Pitch_Class_Distribution_1',
791         ...
792         ...
793         'Pitch_Class_Distribution_11',
794         'Changes_of_Meter',
795         'Composer']
796        '''
797        post = []
798        # place ids first
799        if includeId:
800            post.append('Identifier')
801        for fe in self._instantiatedFeatureExtractors:
802            post += fe.getAttributeLabels()
803        if self._classLabel is not None and includeClassLabel:
804            post.append(self._classLabel.replace(' ', '_'))
805        return post
806
807    def getDiscreteLabels(self, includeClassLabel=True, includeId=True):
808        '''
809        Return column labels for discrete status.
810
811        >>> f = [features.jSymbolic.PitchClassDistributionFeature,
812        ...      features.jSymbolic.ChangesOfMeterFeature]
813        >>> ds = features.DataSet(classLabel='Composer', featureExtractors=f)
814        >>> ds.getDiscreteLabels()
815        [None, False, False, False, False, False, False, False, False, False,
816         False, False, False, True, True]
817        '''
818        post = []
819        if includeId:
820            post.append(None)  # just a spacer
821        for fe in self._instantiatedFeatureExtractors:
822            # need as many statements of discrete as there are dimensions
823            post += [fe.discrete] * fe.dimensions
824        # class label is assumed always discrete
825        if self._classLabel is not None and includeClassLabel:
826            post.append(True)
827        return post
828
829    def getClassPositionLabels(self, includeId=True):
830        '''
831        Return column labels for the presence of a class definition
832
833        >>> f = [features.jSymbolic.PitchClassDistributionFeature,
834        ...      features.jSymbolic.ChangesOfMeterFeature]
835        >>> ds = features.DataSet(classLabel='Composer', featureExtractors=f)
836        >>> ds.getClassPositionLabels()
837        [None, False, False, False, False, False, False, False, False,
838         False, False, False, False, False, True]
839        '''
840        post = []
841        if includeId:
842            post.append(None)  # just a spacer
843        for fe in self._instantiatedFeatureExtractors:
844            # need as many statements of discrete as there are dimensions
845            post += [False] * fe.dimensions
846        # class label is assumed always discrete
847        if self._classLabel is not None:
848            post.append(True)
849        return post
850
851    def addMultipleData(self, dataList, classValues, ids=None):
852        '''
853        add multiple data points at the same time.
854
855        Requires an iterable (including MetadataBundle) for dataList holding
856        types that can be passed to addData, and an equally sized list of dataValues
857        and an equally sized list of ids (or None)
858
859        classValues can also be a pickleable function that will be called on
860        each instance after parsing, as can ids.
861        '''
862        if (not callable(classValues)
863                and len(dataList) != len(classValues)):
864            raise DataSetException(
865                'If classValues is not a function, it must have the same length as dataList')
866        if (ids is not None
867                and not callable(ids)
868                and len(dataList) != len(ids)):
869            raise DataSetException(
870                'If ids is not a function or None, it must have the same length as dataList')
871
872        if callable(classValues):
873            try:
874                pickle.dumps(classValues)
875            except pickle.PicklingError:
876                raise DataSetException('classValues if a function must be pickleable. '
877                                       + 'Lambda and some other functions are not.')
878
879            classValues = [classValues] * len(dataList)
880
881        if callable(ids):
882            try:
883                pickle.dumps(ids)
884            except pickle.PicklingError:
885                raise DataSetException('ids if a function must be pickleable. '
886                                       + 'Lambda and some other functions are not.')
887
888            ids = [ids] * len(dataList)
889        elif ids is None:
890            ids = [None] * len(dataList)
891
892        for i in range(len(dataList)):
893            d = dataList[i]
894            cv = classValues[i]
895            thisId = ids[i]
896            self.addData(d, cv, thisId)
897
898    # pylint: disable=redefined-builtin
899    def addData(self, dataOrStreamOrPath, classValue=None, id=None):  # @ReservedAssignment
900        '''
901        Add a Stream, DataInstance, MetadataEntry, or path (Posix or str)
902        to a corpus or local file to this data set.
903
904        The class value passed here is assumed to be the same as
905        the classLabel assigned at startup.
906        '''
907        if self._classLabel is None:
908            raise DataSetException(
909                'cannot add data unless a class label for this DataSet has been set.')
910
911        s = None
912        if isinstance(dataOrStreamOrPath, DataInstance):
913            di = dataOrStreamOrPath
914            s = di.stream
915            if s is None:
916                s = di.streamPath
917        else:
918            # all else are stored directly
919            s = dataOrStreamOrPath
920            di = DataInstance(dataOrStreamOrPath, id=id)
921
922        di.setClassLabel(self._classLabel, classValue)
923        self.dataInstances.append(di)
924
925    def process(self):
926        '''
927        Process all Data with all FeatureExtractors.
928        Processed data is stored internally as numerous Feature objects.
929        '''
930        if self.runParallel:
931            return self._processParallel()
932        else:
933            return self._processNonParallel()
934
    def _processParallel(self):
        '''
        Run a set of processes in parallel.

        Each DataInstance is handed to _dataSetParallelSubprocess in a
        worker process; results are collected into self.features, and any
        resolved class values / ids are copied back onto the instances.
        '''
        # give each DataInstance the extractor *classes* (not instances),
        # since classes pickle cleanly for transfer to the subprocess
        for di in self.dataInstances:
            di.featureExtractorClassesForParallelRunning = self._featureExtractors

        # progress updates are suppressed in quiet mode
        shouldUpdate = not self.quiet

        # print('about to run parallel')
        outputData = common.runParallel([(di, self.failFast) for di in self.dataInstances],
                                           _dataSetParallelSubprocess,
                                           updateFunction=shouldUpdate,
                                           updateMultiply=1,
                                           unpackIterable=True
                                        )
        # each worker returns (row, errors, classValue, id); unzip into columns
        featureData, errors, classValues, ids = zip(*outputData)
        errors = common.flattenList(errors)
        for e in errors:
            if self.quiet is True:
                environLocal.printDebug(e)
            else:
                environLocal.warn(e)
        # NOTE(review): zip() makes featureData a tuple of rows here, while
        # _processNonParallel builds a list — callers that only iterate or
        # index are unaffected.
        self.features = featureData

        # callables for classValue/id were evaluated in the subprocess
        # (the local copies are still the callables); store the results
        for i, di in enumerate(self.dataInstances):
            if callable(di._classValue):
                di._classValue = classValues[i]
            if callable(di._id):
                di._id = ids[i]
965
966    def _processNonParallel(self):
967        '''
968        The traditional way: run non-parallel
969        '''
970        # clear features
971        self.features = []
972        for data in self.dataInstances:
973            row = []
974            for fe in self._instantiatedFeatureExtractors:
975                fe.setData(data)
976                # in some cases there might be problem; to not fail
977                try:
978                    fReturned = fe.extract()
979                except Exception as e:  # pylint: disable=broad-except
980                    # for now take any error
981                    fList = ['failed feature extractor:', fe, str(e)]
982                    if self.quiet is True:
983                        environLocal.printDebug(fList)
984                    else:
985                        environLocal.warn(fList)
986                    if self.failFast is True:
987                        raise e
988                    # provide a blank feature extractor
989                    fReturned = fe.getBlankFeature()
990
991                row.append(fReturned)  # get feature and store
992            # rows will align with data the order of DataInstances
993            self.features.append(row)
994
995    def getFeaturesAsList(self, includeClassLabel=True, includeId=True, concatenateLists=True):
996        '''
997        Get processed data as a list of lists, merging any sub-lists
998        in multi-dimensional features.
999        '''
1000        post = []
1001        for i, row in enumerate(self.features):
1002            v = []
1003            di = self.dataInstances[i]
1004
1005            if includeId:
1006                v.append(di.getId())
1007
1008            for f in row:
1009                if concatenateLists:
1010                    v += f.vector
1011                else:
1012                    v.append(f.vector)
1013            if includeClassLabel:
1014                v.append(di.getClassValue())
1015            post.append(v)
1016        if not includeClassLabel and not includeId:
1017            return post[0]
1018        else:
1019            return post
1020
1021    def getUniqueClassValues(self):
1022        '''
1023        Return a list of unique class values.
1024        '''
1025        post = []
1026        for di in self.dataInstances:
1027            v = di.getClassValue()
1028            if v not in post:
1029                post.append(v)
1030        return post
1031
1032    def _getOutputFormat(self, featureFormat):
1033        from music21.features import outputFormats
1034        if featureFormat.lower() in ['tab', 'orange', 'taborange', None]:
1035            outputFormat = outputFormats.OutputTabOrange(dataSet=self)
1036        elif featureFormat.lower() in ['csv', 'comma']:
1037            outputFormat = outputFormats.OutputCSV(dataSet=self)
1038        elif featureFormat.lower() in ['arff', 'attribute']:
1039            outputFormat = outputFormats.OutputARFF(dataSet=self)
1040        else:
1041            return None
1042        return outputFormat
1043
1044    def _getOutputFormatFromFilePath(self, fp):
1045        '''
1046        Get an output format from a file path if possible, otherwise return None.
1047
1048        >>> ds = features.DataSet()
1049        >>> ds._getOutputFormatFromFilePath('test.tab')
1050        <music21.features.outputFormats.OutputTabOrange object at ...>
1051        >>> ds._getOutputFormatFromFilePath('test.csv')
1052        <music21.features.outputFormats.OutputCSV object at ...>
1053        >>> ds._getOutputFormatFromFilePath('junk') is None
1054        True
1055        '''
1056        # get format from fp if possible
1057        of = None
1058        if '.' in fp:
1059            if self._getOutputFormat(fp.split('.')[-1]) is not None:
1060                of = self._getOutputFormat(fp.split('.')[-1])
1061        return of
1062
1063    def getString(self, outputFmt='tab'):
1064        '''
1065        Get a string representation of the data set in a specific format.
1066        '''
1067        # pass reference to self to output
1068        outputFormat = self._getOutputFormat(outputFmt)
1069        return outputFormat.getString()
1070
1071    # pylint: disable=redefined-builtin
1072    def write(self, fp=None, format=None, includeClassLabel=True):  # @ReservedAssignment
1073        '''
1074        Set the output format object.
1075        '''
1076        if format is None and fp is not None:
1077            outputFormat = self._getOutputFormatFromFilePath(fp)
1078        else:
1079            outputFormat = self._getOutputFormat(format)
1080        if outputFormat is None:
1081            raise DataSetException('no output format could be defined from file path '
1082                                   + f'{fp} or format {format}')
1083
1084        return outputFormat.write(fp=fp, includeClassLabel=includeClassLabel)
1085
1086
1087def _dataSetParallelSubprocess(dataInstance, failFast):
1088    row = []
1089    errors = []
1090    # howBigWeCopied = len(pickle.dumps(dataInstance))
1091    # print('Starting ', dataInstance, ' Size: ', howBigWeCopied)
1092    for feClass in dataInstance.featureExtractorClassesForParallelRunning:
1093        fe = feClass()
1094        fe.setData(dataInstance)
1095        # in some cases there might be problem; to not fail
1096        try:
1097            fReturned = fe.extract()
1098        except Exception as e:  # pylint: disable=broad-except
1099            # for now take any error
1100            errors.append('failed feature extractor:' + str(fe) + ': ' + str(e))
1101            if failFast:
1102                raise e
1103            # provide a blank feature extractor
1104            fReturned = fe.getBlankFeature()
1105
1106        row.append(fReturned)  # get feature and store
1107    # rows will align with data the order of DataInstances
1108    return row, errors, dataInstance.getClassValue(), dataInstance.getId()
1109
1110
def allFeaturesAsList(streamInput):
    '''
    returns a list containing ALL currently implemented feature extractors

    streamInput can be a Stream, DataInstance, or path to a corpus or local
    file to this data set.

    >>> s = converter.parse('tinynotation: 4/4 c4 d e2')
    >>> f = features.allFeaturesAsList(s)
    >>> f[2:5]
    [[2], [2], [1.0]]
    >>> len(f) > 85
    True
    '''
    from music21.features import jSymbolic, native
    # combine every extractor from both libraries into one DataSet
    ds = DataSet(classLabel='')
    allExtractors = list(jSymbolic.featureExtractors) + list(native.featureExtractors)
    ds.addFeatureExtractors(allExtractors)
    ds.addData(streamInput)
    ds.process()
    # one un-concatenated vector per extractor, no id or class columns
    return ds.getFeaturesAsList(includeClassLabel=False,
                                includeId=False,
                                concatenateLists=False)
1136
1137
1138# ------------------------------------------------------------------------------
def extractorsById(idOrList, library=('jSymbolic', 'native')):
    '''
    Given one or more :class:`~music21.features.FeatureExtractor` ids, return the
    appropriate  subclass. An optional `library` argument can be added to define which
    module is used. Current options are jSymbolic and native.

    >>> features.extractorsById('p20')
    [<class 'music21.features.jSymbolic.PitchClassDistributionFeature'>]

    >>> [x.id for x in features.extractorsById('p20')]
    ['P20']
    >>> [x.id for x in features.extractorsById(['p19', 'p20'])]
    ['P19', 'P20']


    Normalizes case...

    >>> [x.id for x in features.extractorsById(['r31', 'r32', 'r33', 'r34', 'r35', 'p1', 'p2'])]
    ['R31', 'R32', 'R33', 'R34', 'R35', 'P1', 'P2']

    Get all feature extractors from all libraries

    >>> y = [x.id for x in features.extractorsById('all')]
    >>> y[0:3], y[-3:-1]
    (['M1', 'M2', 'M3'], ['CS12', 'MC1'])

    '''
    from music21.features import jSymbolic
    from music21.features import native

    if not common.isIterable(library):
        library = [library]

    featureExtractors = []
    for lib in library:
        libLower = lib.lower()
        # two separate if-statements: 'all' must add BOTH libraries.
        # (the old elif skipped native whenever 'all' matched the
        # jSymbolic branch first)
        if libLower in ('jsymbolic', 'all'):
            featureExtractors += jSymbolic.featureExtractors
        if libLower in ('native', 'all'):
            featureExtractors += native.featureExtractors

    if not common.isIterable(idOrList):
        idOrList = [idOrList]

    flatIds = []
    for featureId in idOrList:
        # normalize: lowercase, strip whitespace, remove hyphens and
        # internal spaces.  (the old code discarded the results of
        # str.replace — strings are immutable — so 'p-20' never matched)
        featureId = featureId.strip().lower()
        featureId = featureId.replace('-', '')
        featureId = featureId.replace(' ', '')
        flatIds.append(featureId)

    post = []
    if not flatIds:
        return post

    # flatIds are already lowercase; hoist the 'all' check out of the loop
    matchAll = flatIds[0] == 'all'
    for fe in featureExtractors:
        if matchAll or fe.id.lower() in flatIds:
            post.append(fe)
    return post
1197
1198
def extractorById(idOrList, library=('jSymbolic', 'native')):
    '''
    Get the first feature matched by extractorsById().

    >>> s = stream.Stream()
    >>> s.append(note.Note('A4'))
    >>> fe = features.extractorById('p20')(s)  # call class
    >>> fe.extract().vector
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    '''
    matches = extractorsById(idOrList=idOrList, library=library)
    if not matches:
        return None  # no match
    return matches[0]
1214
1215
def vectorById(streamObj, vectorId, library=('jSymbolic', 'native')):
    '''
    Utility function to get a vector from an extractor

    Returns None when vectorId matches no known extractor.

    >>> s = stream.Stream()
    >>> s.append(note.Note('A4'))
    >>> features.vectorById(s, 'p20')
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    '''
    # look up the extractor class before calling it: the original code
    # called extractorById(...)(...) immediately, so an unknown id raised
    # TypeError ('NoneType' object is not callable) before the None check
    # below could ever run.
    extractorClass = extractorById(vectorId, library)
    if extractorClass is None:
        return None  # could raise exception
    fe = extractorClass(streamObj)  # call class with stream
    return fe.extract().vector
1229
1230
def getIndex(featureString, extractorType=None):
    '''
    Returns the list index of the given feature extractor and the feature extractor
    category (jsymbolic or native). If feature extractor string is not in either
    jsymbolic or native feature extractors, returns None

    optionally include the extractorType ('jsymbolic' or 'native') if known
    and searching will be made more efficient


    >>> features.getIndex('Range')
    (61, 'jsymbolic')
    >>> features.getIndex('Ends With Landini Melodic Contour')
    (18, 'native')
    >>> features.getIndex('aBrandNewFeature!') is None
    True
    >>> features.getIndex('Fifths Pitch Histogram', 'jsymbolic')
    (70, 'jsymbolic')
    >>> features.getIndex('Tonal Certainty', 'native')
    (1, 'native')
    '''
    from music21.features import jSymbolic, native

    # compare against the instantiated extractor's .name in each library
    if extractorType in (None, 'jsymbolic'):
        for i, feature in enumerate(jSymbolic.featureExtractors):
            if feature().name == featureString:
                return (i, 'jsymbolic')
    if extractorType in (None, 'native'):
        for i, feature in enumerate(native.featureExtractors):
            if feature().name == featureString:
                return (i, 'native')
    # no match in any searched library.  (the old code only had an explicit
    # `return None` inside the native branch; a jsymbolic-only miss fell off
    # the end of the function — same result, now explicit for both paths)
    return None
1269
1270
1271# ------------------------------------------------------------------------------
1272class Test(unittest.TestCase):
1273
    def testStreamFormsA(self):
        # Checks DataInstance's cached "stream forms" (flat, chordify, parts,
        # and derived histograms) against known values from a corpus score.

        from music21 import features
        self.maxDiff = None

        s = corpus.parse('corelli/opus3no1/1grave')
        # s.chordify().show()
        di = features.DataInstance(s)
        self.assertEqual(len(di['flat']), 291)
        self.assertEqual(len(di['flat.notes']), 238)

        # di['chordify'].show('t')
        self.assertEqual(len(di['chordify']), 27)
        chordifiedChords = di['chordify.flat.getElementsByClass(Chord)']
        self.assertEqual(len(chordifiedChords), 145)
        histo = di['chordify.flat.getElementsByClass(Chord).setClassHistogram']
        # print(histo)

        # Forte set-class histogram of the chordified score
        self.assertEqual(histo,
                         {'3-11': 30, '2-4': 26, '1-1': 25, '2-3': 16, '3-9': 12, '2-2': 6,
                          '3-7': 6, '2-5': 6, '3-4': 5, '3-6': 5, '3-10': 4,
                          '3-8': 2, '3-2': 2})

        # counts of chord-quality predicates across the chordified score
        self.assertEqual(di['chordify.flat.getElementsByClass(Chord).typesHistogram'],
                           {'isMinorTriad': 6, 'isAugmentedTriad': 0,
                            'isTriad': 34, 'isSeventh': 0, 'isDiminishedTriad': 4,
                            'isDiminishedSeventh': 0, 'isIncompleteMajorTriad': 26,
                            'isHalfDiminishedSeventh': 0, 'isMajorTriad': 24,
                            'isDominantSeventh': 0, 'isIncompleteMinorTriad': 16})

        self.assertEqual(di['flat.notes.quarterLengthHistogram'],
                         {0.5: 116, 1.0: 39, 1.5: 27, 2.0: 31, 3.0: 2, 4.0: 3,
                          0.75: 4, 0.25: 16})

        # can access parts by index
        self.assertEqual(len(di['parts']), 3)
        # stored in parts are StreamForms instances, caching their results
        self.assertEqual(len(di['parts'][0]['flat.notes']), 71)
        self.assertEqual(len(di['parts'][1]['flat.notes']), 66)

        # getting a measure by part
        self.assertEqual(len(di['parts'][0]['getElementsByClass(Measure)']), 19)
        self.assertEqual(len(di['parts'][1]['getElementsByClass(Measure)']), 19)

        self.assertEqual(di['parts'][0]['pitches.pitchClassHistogram'],
                         [9, 1, 11, 0, 9, 13, 0, 11, 0, 12, 5, 0])
        # the sum of the two arrays is the pitch class histogram of the complete
        # work
        self.assertEqual(di['pitches.pitchClassHistogram'],
                         [47, 2, 25, 0, 25, 42, 0, 33, 0, 38, 22, 4])
1324
    def testStreamFormsB(self):
        # midiIntervalHistogram: index = melodic interval size in semitones,
        # value = count of that interval in the note sequence.
        from music21 import features

        s = stream.Stream()
        for p in ['c4', 'c4', 'd-4', 'd#4', 'f#4', 'a#4', 'd#5', 'a5', 'a5']:
            s.append(note.Note(p))
        di = features.DataInstance(s)
        # intervals: 0 (c4-c4), 1 (c4-d-4), 2 (d-4-d#4), 3 (d#4-f#4),
        # 4 (f#4-a#4), 5 (a#4-d#5), 6 (d#5-a5), 0 (a5-a5)
        self.assertEqual(di['pitches.midiIntervalHistogram'],
                         [2, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
1340
    def testStreamFormsC(self):
        # Checks the 'flat.secondsMap' form: one dict per note with offset,
        # duration, and end time in seconds (0.5s per quarter at default tempo).
        from pprint import pformat
        from music21 import features

        s = stream.Stream()
        for p in ['c4', 'c4', 'd-4', 'd#4', 'f#4', 'a#4', 'd#5', 'a5']:
            s.append(note.Note(p))
        di = features.DataInstance(s)

        self.assertEqual(pformat(di['flat.secondsMap']), '''[{'durationSeconds': 0.5,
  'element': <music21.note.Note C>,
  'endTimeSeconds': 0.5,
  'offsetSeconds': 0.0,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note C>,
  'endTimeSeconds': 1.0,
  'offsetSeconds': 0.5,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note D->,
  'endTimeSeconds': 1.5,
  'offsetSeconds': 1.0,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note D#>,
  'endTimeSeconds': 2.0,
  'offsetSeconds': 1.5,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note F#>,
  'endTimeSeconds': 2.5,
  'offsetSeconds': 2.0,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note A#>,
  'endTimeSeconds': 3.0,
  'offsetSeconds': 2.5,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note D#>,
  'endTimeSeconds': 3.5,
  'offsetSeconds': 3.0,
  'voiceIndex': None},
 {'durationSeconds': 0.5,
  'element': <music21.note.Note A>,
  'endTimeSeconds': 4.0,
  'offsetSeconds': 3.5,
  'voiceIndex': None}]''', pformat(di['secondsMap']))
1390
    def testDataSetOutput(self):
        # End-to-end: extract a few native features from two corpus works
        # and check the CSV serialization, then exercise the file writers.
        from music21 import features
        from music21.features import outputFormats
        # test just a few features
        featureExtractors = features.extractorsById(['ql1', 'ql2', 'ql4'], 'native')

        # need to define what the class label will be
        ds = features.DataSet(classLabel='Composer')
        ds.runParallel = False
        ds.addFeatureExtractors(featureExtractors)

        # add works, defining the class value
        ds.addData('bwv66.6', classValue='Bach')
        ds.addData('corelli/opus3no1/1grave', classValue='Corelli')

        ds.process()

        # manually create an output format and get output
        of = outputFormats.OutputCSV(ds)
        post = of.getString(lineBreak='//')
        self.assertEqual(
            post,
            'Identifier,Unique_Note_Quarter_Lengths,'
            'Most_Common_Note_Quarter_Length,Range_of_Note_Quarter_Lengths,'
            'Composer//bwv66.6,3,1.0,1.5,Bach//corelli/opus3no1/1grave,8,0.5,3.75,Corelli')

        # without id
        post = of.getString(lineBreak='//', includeId=False)
        self.assertEqual(
            post,
            'Unique_Note_Quarter_Lengths,Most_Common_Note_Quarter_Length,'
            'Range_of_Note_Quarter_Lengths,Composer//3,1.0,1.5,Bach//8,0.5,3.75,Corelli')

        # write() with format only uses a temp path chosen by the format
        fp1 = ds.write(format='tab')
        fp2 = ds.write(format='csv')
        # Also test providing fp
        fp3 = environLocal.getTempFile(suffix='.arff')
        ds.write(fp=fp3, format='arff')

        # clean up the temporary files created above
        for fp in (fp1, fp2, fp3):
            os.remove(fp)
1432
    def testFeatureFail(self):
        # With failFast=True, a failing extractor must propagate an exception
        # out of DataSet.process() instead of yielding a blank feature.
        from music21 import features
        from music21 import base

        featureExtractors = ['p10', 'p11', 'p12', 'p13']

        featureExtractors = features.extractorsById(featureExtractors,
                                                    'jSymbolic')

        ds = features.DataSet(classLabel='Composer')
        ds.addFeatureExtractors(featureExtractors)

        # create problematic streams
        s = stream.Stream()
        # s.append(None)  # will create a wrapper -- NOT ANYMORE
        s.append(base.ElementWrapper(None))
        ds.addData(s, classValue='Monteverdi')
        ds.addData(s, classValue='Handel')

        # process with all feature extractors, store all features
        ds.failFast = True
        # Tests that some exception is raised, not necessarily that only one is
        with self.assertRaises(features.FeatureException):
            ds.process()
1457
    def testEmptyStreamCustomErrors(self):
        # Every extractor run on empty/degenerate streams should raise only
        # the expected, wrapped exception types (or none at all).
        from music21 import analysis
        from music21 import features
        from music21.features import jSymbolic, native

        ds = DataSet(classLabel='')
        f = list(jSymbolic.featureExtractors) + list(native.featureExtractors)

        # degenerate inputs: a bare Stream and a Score with one empty Measure
        bareStream = stream.Stream()
        bareScore = stream.Score()

        singlePart = stream.Part()
        singleMeasure = stream.Measure()
        singlePart.append(singleMeasure)
        bareScore.insert(singlePart)

        ds.addData(bareStream)
        ds.addData(bareScore)
        ds.addFeatureExtractors(f)

        for data in ds.dataInstances:
            for fe in ds._instantiatedFeatureExtractors:
                fe.setData(data)
                try:
                    fe.extract()
                # is every error wrapped?
                except (features.FeatureException,
                        analysis.discrete.DiscreteAnalysisException):
                    pass
1487
1488    # --------------------------------------------------------------------------
1489    # silent tests
1490
1491#    def testGetAllExtractorsMethods(self):
1492#        '''
1493#        ahh..this test takes a really long time....
1494#        '''
1495#        from music21 import stream, features, pitch
1496#        s = corpus.parse('bwv66.6').measures(1, 5)
1497#        self.assertEqual( len(features.alljSymbolicFeatures(s)), 70)
1498#        self.assertEqual(len (features.allNativeFeatures(s)),21)
1499#        self.assertEqual(str(features.alljSymbolicVectors(s)[1:5]),
1500# '[[2.6630434782608696], [2], [2], [0.391304347826087]]')
1501#        self.assertEqual(str(features.allNativeVectors(s)[0:4]),
1502# '[[1], [1.0328322202181006], [2], [1.0]]')
1503
    def x_testComposerClassificationJSymbolic(self):  # pragma: no cover
        '''
        Demonstrating writing out data files for feature extraction. Here,
        features are used from the jSymbolic library.
        '''
        # NOTE: disabled by the x_ prefix — parses many corpus works and is slow.
        from music21 import features

        featureExtractors = ['r31', 'r32', 'r33', 'r34', 'r35', 'p1', 'p2', 'p3',
                             'p4', 'p5', 'p6', 'p7', 'p8', 'p9', 'p10', 'p11', 'p12',
                             'p13', 'p14', 'p15', 'p16', 'p19', 'p20', 'p21']

        # will return a list
        featureExtractors = features.extractorsById(featureExtractors,
                                                    'jSymbolic')

        # worksBach = corpus.getBachChorales()[100:143]  # a middle range
        worksMonteverdi = corpus.search('monteverdi').search('.xml')[:43]

        worksBach = corpus.search('bach').search(numberOfParts=4)[:5]

        # need to define what the class label will be
        ds = features.DataSet(classLabel='Composer')
        ds.addFeatureExtractors(featureExtractors)

        # add works, defining the class value
#         for w in worksBach:
#             ds.addData(w, classValue='Bach')
        for w in worksMonteverdi:
            ds.addData(w, classValue='Monteverdi')
        for w in worksBach:
            ds.addData(w, classValue='Bach')

        # process with all feature extractors, store all features
        ds.process()
        ds.write(format='tab')
        ds.write(format='csv')
        ds.write(format='arff')
1541
    def x_testRegionClassificationJSymbolicA(self):  # pragma: no cover
        '''
        Demonstrating writing out data files for feature extraction. Here,
        features are used from the jSymbolic library.
        '''
        # NOTE: disabled by the x_ prefix — parses four Essen folksong
        # collections and is slow.
        from music21 import features

        featureExtractors = features.extractorsById(['r31', 'r32', 'r33', 'r34', 'r35',
                                                     'p1', 'p2', 'p3', 'p4', 'p5', 'p6',
                                                     'p7', 'p8', 'p9', 'p10', 'p11', 'p12',
                                                     'p13', 'p14', 'p15', 'p16', 'p19',
                                                     'p20', 'p21'], 'jSymbolic')

        oChina1 = corpus.parse('essenFolksong/han1')
        oChina2 = corpus.parse('essenFolksong/han2')

        oMitteleuropa1 = corpus.parse('essenFolksong/boehme10')
        oMitteleuropa2 = corpus.parse('essenFolksong/boehme20')

        ds = features.DataSet(classLabel='Region')
        ds.addFeatureExtractors(featureExtractors)

        # add works, defining the class value
        for o, name in [(oChina1, 'han1'),
                        (oChina2, 'han2')]:
            for w in o.scores:
                songId = f'essenFolksong/{name}-{w.metadata.number}'
                ds.addData(w, classValue='China', id=songId)

        for o, name in [(oMitteleuropa1, 'boehme10'),
                        (oMitteleuropa2, 'boehme20')]:
            for w in o.scores:
                songId = f'essenFolksong/{name}-{w.metadata.number}'
                ds.addData(w, classValue='Mitteleuropa', id=songId)

        # process with all feature extractors, store all features
        ds.process()
        ds.getString(outputFmt='tab')
        ds.getString(outputFmt='csv')
        ds.getString(outputFmt='arff')
1582
    def x_testRegionClassificationJSymbolicB(self):  # pragma: no cover
        '''
        Demonstrating writing out data files for feature extraction.
        Here, features are used from the jSymbolic library.
        '''
        # NOTE: disabled by the x_ prefix — parses four Essen folksong
        # collections and writes to a hard-coded /_scratch directory.
        from music21 import features

        # features common to both collections
        featureExtractors = features.extractorsById(
            ['r31', 'r32', 'r33', 'r34', 'r35', 'p1', 'p2', 'p3', 'p4',
                             'p5', 'p6', 'p7', 'p8', 'p9', 'p10', 'p11', 'p12', 'p13',
                             'p14', 'p15', 'p16', 'p19', 'p20', 'p21'], 'jSymbolic')

        # first bundle
        ds = features.DataSet(classLabel='Region')
        ds.addFeatureExtractors(featureExtractors)

        oChina1 = corpus.parse('essenFolksong/han1')
        oMitteleuropa1 = corpus.parse('essenFolksong/boehme10')

        # add works, defining the class value
        for o, name in [(oChina1, 'han1')]:
            for w in o.scores:
                songId = f'essenFolksong/{name}-{w.metadata.number}'
                ds.addData(w, classValue='China', id=songId)

        for o, name in [(oMitteleuropa1, 'boehme10')]:
            for w in o.scores:
                songId = f'essenFolksong/{name}-{w.metadata.number}'
                ds.addData(w, classValue='Mitteleuropa', id=songId)

        # process with all feature extractors, store all features
        ds.process()
        ds.write('/_scratch/chinaMitteleuropaSplit-a.tab')
        ds.write('/_scratch/chinaMitteleuropaSplit-a.csv')
        ds.write('/_scratch/chinaMitteleuropaSplit-a.arff')

        # create second data set from alternate collections
        ds = features.DataSet(classLabel='Region')
        ds.addFeatureExtractors(featureExtractors)

        oChina2 = corpus.parse('essenFolksong/han2')
        oMitteleuropa2 = corpus.parse('essenFolksong/boehme20')
        # add works, defining the class value
        for o, name in [(oChina2, 'han2')]:
            for w in o.scores:
                songId = f'essenFolksong/{name}-{w.metadata.number}'
                ds.addData(w, classValue='China', id=songId)

        for o, name in [(oMitteleuropa2, 'boehme20')]:
            for w in o.scores:
                songId = f'essenFolksong/{name}-{w.metadata.number}'
                ds.addData(w, classValue='Mitteleuropa', id=songId)

        # process with all feature extractors, store all features
        ds.process()
        ds.write('/_scratch/chinaMitteleuropaSplit-b.tab')
        ds.write('/_scratch/chinaMitteleuropaSplit-b.csv')
        ds.write('/_scratch/chinaMitteleuropaSplit-b.arff')
1642
# The disabled tests below were written against the Python 2 Orange library;
# they need to be rewritten for a maintained machine-learning package.
1644#     def xtestOrangeBayesA(self):  # pragma: no cover
1645#         '''Using an already created test file with a BayesLearner.
1646#         '''
1647#         import orange # @UnresolvedImport  # pylint: disable=import-error
1648#         data = orange.ExampleTable(
1649#             '~/music21Ext/mlDataSets/bachMonteverdi-a/bachMonteverdi-a.tab')
1650#         classifier = orange.BayesLearner(data)
1651#         for i in range(len(data)):
1652#             c = classifier(data[i])
1653#             print('original', data[i].getclass(), 'BayesLearner:', c)
1654#
1655#
1656#     def xtestClassifiersA(self):  # pragma: no cover
1657#         '''Using an already created test file with a BayesLearner.
1658#         '''
1659#         import orange, orngTree # @UnresolvedImport  # pylint: disable=import-error
1660#         data1 = orange.ExampleTable(
1661#                 '~/music21Ext/mlDataSets/chinaMitteleuropa-b/chinaMitteleuropa-b1.tab')
1662#
1663#         data2 = orange.ExampleTable(
1664#                 '~/music21Ext/mlDataSets/chinaMitteleuropa-b/chinaMitteleuropa-b2.tab')
1665#
1666#         majority = orange.MajorityLearner
1667#         bayes = orange.BayesLearner
1668#         tree = orngTree.TreeLearner
1669#         knn = orange.kNNLearner
1670#
1671#         for classifierType in [majority, bayes, tree, knn]:
1672#             print('')
1673#             for classifierData, classifierStr, matchData, matchStr in [
1674#                 (data1, 'data1', data1, 'data1'),
1675#                 (data1, 'data1', data2, 'data2'),
1676#                 (data2, 'data2', data2, 'data2'),
1677#                 (data2, 'data2', data1, 'data1'),
1678#                 ]:
1679#
1680#                 # train with data1
1681#                 classifier = classifierType(classifierData)
1682#                 mismatch = 0
1683#                 for i in range(len(matchData)):
1684#                     c = classifier(matchData[i])
1685#                     if c != matchData[i].getclass():
1686#                         mismatch += 1
1687#
1688#                 print('%s %s: misclassified %s/%s of %s' % (
1689#                         classifierStr, classifierType, mismatch, len(matchData), matchStr))
1690#
1691# #             if classifierType == orngTree.TreeLearner:
1692# #                 orngTree.printTxt(classifier)
1693#
1694#
1695#
1696#     def xtestClassifiersB(self):  # pragma: no cover
1697#         '''Using an already created test file with a BayesLearner.
1698#         '''
1699#         import orange, orngTree # @UnresolvedImport  # pylint: disable=import-error
1700#         data1 = orange.ExampleTable(
1701#                 '~/music21Ext/mlDataSets/chinaMitteleuropa-b/chinaMitteleuropa-b1.tab')
1702#
1703#         data2 = orange.ExampleTable(
1704#                 '~/music21Ext/mlDataSets/chinaMitteleuropa-b/chinaMitteleuropa-b2.tab',
1705#                 use=data1.domain)
1706#
1707#         data1.extend(data2)
1708#         data = data1
1709#
1710#         majority = orange.MajorityLearner
1711#         bayes = orange.BayesLearner
1712#         tree = orngTree.TreeLearner
1713#         knn = orange.kNNLearner
1714#
1715#         folds = 10
1716#         for classifierType in [majority, bayes, tree, knn]:
1717#             print('')
1718#
1719#             cvIndices = orange.MakeRandomIndicesCV(data, folds)
1720#             for fold in range(folds):
1721#                 train = data.select(cvIndices, fold, negate=1)
1722#                 test = data.select(cvIndices, fold)
1723#
1724#                 for classifierData, classifierStr, matchData, matchStr in [
1725#                     (train, 'train', test, 'test'),
1726#                     ]:
1727#
1728#                     # train with data1
1729#                     classifier = classifierType(classifierData)
1730#                     mismatch = 0
1731#                     for i in range(len(matchData)):
1732#                         c = classifier(matchData[i])
1733#                         if c != matchData[i].getclass():
1734#                             mismatch += 1
1735#
1736#                     print('%s %s: misclassified %s/%s of %s' % (
1737#                             classifierStr, classifierType, mismatch, len(matchData), matchStr))
1738#
1739#
1740#     def xtestOrangeClassifiers(self):  # pragma: no cover
1741#         '''
1742#         This test shows how to compare four classifiers; replace the file path
1743#         with a path to the .tab data file.
1744#         '''
1745#         import orange, orngTree # @UnresolvedImport  # pylint: disable=import-error
1746#         data = orange.ExampleTable(
1747#             '~/music21Ext/mlDataSets/bachMonteverdi-a/bachMonteverdi-a.tab')
1748#
1749#         # setting up the classifiers
1750#         majority = orange.MajorityLearner(data)
1751#         bayes = orange.BayesLearner(data)
1752#         tree = orngTree.TreeLearner(data, sameMajorityPruning=1, mForPruning=2)
1753#         knn = orange.kNNLearner(data, k=21)
1754#
1755#         majority.name='Majority'
1756#         bayes.name='Naive Bayes'
1757#         tree.name='Tree'
1758#         knn.name='kNN'
1759#         classifiers = [majority, bayes, tree, knn]
1760#
1761#         # print the head
1762#         print('Possible classes:', data.domain.classVar.values)
1763#         print('Original Class', end=' ')
1764#         for l in classifiers:
1765#             print('%-13s' % (l.name), end=' ')
1766#         print()
1767#
1768#         for example in data:
1769#             print('(%-10s)  ' % (example.getclass()), end=' ')
1770#             for c in classifiers:
1771#                 p = c([example, orange.GetProbabilities])
1772#                 print('%5.3f        ' % (p[0]), end=' ')
1773#             print('')
1774#
1775#
1776#     def xtestOrangeClassifierTreeLearner(self):  # pragma: no cover
1777#         import orange, orngTree # @UnresolvedImport  # pylint: disable=import-error
1778#         data = orange.ExampleTable(
1779#             '~/music21Ext/mlDataSets/bachMonteverdi-a/bachMonteverdi-a.tab')
1780#
1781#         tree = orngTree.TreeLearner(data, sameMajorityPruning=1, mForPruning=2)
1782#         # tree = orngTree.TreeLearner(data)
1783#         for i in range(len(data)):
1784#             p = tree(data[i], orange.GetProbabilities)
1785#             print('%s: %5.3f (originally %s)' % (i + 1, p[1], data[i].getclass()))
1786#
1787#         orngTree.printTxt(tree)
1788
1789    def testParallelRun(self):
1790        from music21 import features
1791        # test just a few features
1792        featureExtractors = features.extractorsById(['ql1', 'ql2', 'ql4'], 'native')
1793
1794        # need to define what the class label will be
1795        ds = features.DataSet(classLabel='Composer')
1796        ds.addFeatureExtractors(featureExtractors)
1797
1798        # add works, defining the class value
1799        ds.addData('bwv66.6', classValue='Bach')
1800        ds.addData('corelli/opus3no1/1grave', classValue='Corelli')
1801        ds.runParallel = True
1802        ds.quiet = True
1803        ds.process()
1804        self.assertEqual(len(ds.features), 2)
1805        self.assertEqual(len(ds.features[0]), 3)
1806        fe00 = ds.features[0][0]
1807        self.assertEqual(fe00.vector, [3])
1808
1809    # pylint: disable=redefined-outer-name
1810    def x_fix_parallel_first_testMultipleSearches(self):
1811        from music21.features import outputFormats
1812        from music21 import features
1813
1814        # Need explicit import for pickling within the testSingleCoreAll context
1815        from music21.features.base import _pickleFunctionNumPitches
1816        import textwrap
1817
1818        self.maxDiff = None
1819
1820        fewBach = corpus.search('bach/bwv6')
1821
1822        self.assertEqual(len(fewBach), 13)
1823        ds = features.DataSet(classLabel='NumPitches')
1824        ds.addMultipleData(fewBach, classValues=_pickleFunctionNumPitches)
1825        featureExtractors = features.extractorsById(['ql1', 'ql4'], 'native')
1826        ds.addFeatureExtractors(featureExtractors)
1827        ds.runParallel = True
1828        ds.process()
1829        # manually create an output format and get output
1830        of = outputFormats.OutputCSV(ds)
1831        post = of.getString(lineBreak='\n')
1832        self.assertEqual(post.strip(), textwrap.dedent('''
1833            Identifier,Unique_Note_Quarter_Lengths,Range_of_Note_Quarter_Lengths,NumPitches
1834            bach/bwv6.6.mxl,4,1.75,164
1835            bach/bwv60.5.mxl,6,2.75,282
1836            bach/bwv62.6.mxl,5,1.75,182
1837            bach/bwv64.2.mxl,4,1.5,179
1838            bach/bwv64.4.mxl,5,2.5,249
1839            bach/bwv64.8.mxl,5,3.5,188
1840            bach/bwv65.2.mxl,4,3.0,148
1841            bach/bwv65.7.mxl,7,2.75,253
1842            bach/bwv66.6.mxl,3,1.5,165
1843            bach/bwv67.4.xml,3,1.5,173
1844            bach/bwv67.7.mxl,4,2.5,132
1845            bach/bwv69.6-a.mxl,4,1.5,170
1846            bach/bwv69.6.xml,8,4.25,623
1847            ''').strip())
1848
1849
1850def _pickleFunctionNumPitches(bachStream):
1851    '''
1852    A function for documentation testing of a pickleable function
1853    '''
1854    return len(bachStream.pitches)
1855
1856
1857# ------------------------------------------------------------------------------
1858# define presented order in documentation
1859_DOC_ORDER = [DataSet, Feature, FeatureExtractor]
1860
1861
1862if __name__ == '__main__':
1863    import music21
1864    music21.mainTest(Test)  # , runTest='testStreamFormsA')
1865
1866