1"""The Instrument class holds all events for a single instrument and contains
2functions for extracting information from the events it contains.
3"""
4import numpy as np
5try:
6    import fluidsynth
7    _HAS_FLUIDSYNTH = True
8except ImportError:
9    _HAS_FLUIDSYNTH = False
10import os
11import pkg_resources
12
13from .containers import PitchBend
14from .utilities import pitch_bend_to_semitones, note_number_to_hz
15
# File name of the soundfont that Instrument.fluidsynth() falls back to when
# no ``sf2_path`` is supplied; it is resolved relative to this module through
# ``pkg_resources.resource_filename``.
DEFAULT_SF2 = 'TimGM6mb.sf2'
17
18
19class Instrument(object):
20    """Object to hold event information for a single instrument.
21
22    Parameters
23    ----------
24    program : int
25        MIDI program number (instrument index), in ``[0, 127]``.
26    is_drum : bool
27        Is the instrument a drum instrument (channel 9)?
28    name : str
29        Name of the instrument.
30
31    Attributes
32    ----------
33    program : int
34        The program number of this instrument.
35    is_drum : bool
36        Is the instrument a drum instrument (channel 9)?
37    name : str
38        Name of the instrument.
39    notes : list
40        List of :class:`pretty_midi.Note` objects.
41    pitch_bends : list
        List of :class:`pretty_midi.PitchBend` objects.
43    control_changes : list
44        List of :class:`pretty_midi.ControlChange` objects.
45
46    """
47
48    def __init__(self, program, is_drum=False, name=''):
49        """Create the Instrument.
50
51        """
52        self.program = program
53        self.is_drum = is_drum
54        self.name = name
55        self.notes = []
56        self.pitch_bends = []
57        self.control_changes = []
58
59    def get_onsets(self):
60        """Get all onsets of all notes played by this instrument.
61        May contain duplicates.
62
63        Returns
64        -------
65        onsets : np.ndarray
66                List of all note onsets.
67
68        """
69        onsets = []
70        # Get the note-on time of each note played by this instrument
71        for note in self.notes:
72            onsets.append(note.start)
73        # Return them sorted (because why not?)
74        return np.sort(onsets)
75
76    def get_piano_roll(self, fs=100, times=None,
77                       pedal_threshold=64):
78        """Compute a piano roll matrix of this instrument.
79
80        Parameters
81        ----------
82        fs : int
83            Sampling frequency of the columns, i.e. each column is spaced apart
84            by ``1./fs`` seconds.
85        times : np.ndarray
86            Times of the start of each column in the piano roll.
87            Default ``None`` which is ``np.arange(0, get_end_time(), 1./fs)``.
88        pedal_threshold : int
89            Value of control change 64 (sustain pedal) message that is less
90            than this value is reflected as pedal-off.  Pedals will be
91            reflected as elongation of notes in the piano roll.
92            If None, then CC64 message is ignored.
93            Default is 64.
94
95        Returns
96        -------
97        piano_roll : np.ndarray, shape=(128,times.shape[0])
98            Piano roll of this instrument.
99
100        """
101        # If there are no notes, return an empty matrix
102        if self.notes == []:
103            return np.array([[]]*128)
104        # Get the end time of the last event
105        end_time = self.get_end_time()
106        # Extend end time if one was provided
107        if times is not None and times[-1] > end_time:
108            end_time = times[-1]
109        # Allocate a matrix of zeros - we will add in as we go
110        piano_roll = np.zeros((128, int(fs*end_time)))
111        # Drum tracks don't have pitch, so return a matrix of zeros
112        if self.is_drum:
113            if times is None:
114                return piano_roll
115            else:
116                return np.zeros((128, times.shape[0]))
117        # Add up piano roll matrix, note-by-note
118        for note in self.notes:
119            # Should interpolate
120            piano_roll[note.pitch,
121                       int(note.start*fs):int(note.end*fs)] += note.velocity
122
123        # Process sustain pedals
124        if pedal_threshold is not None:
125            CC_SUSTAIN_PEDAL = 64
126            time_pedal_on = 0
127            is_pedal_on = False
128            for cc in [_e for _e in self.control_changes
129                       if _e.number == CC_SUSTAIN_PEDAL]:
130                time_now = int(cc.time*fs)
131                is_current_pedal_on = (cc.value >= pedal_threshold)
132                if not is_pedal_on and is_current_pedal_on:
133                    time_pedal_on = time_now
134                    is_pedal_on = True
135                elif is_pedal_on and not is_current_pedal_on:
136                    # For each pitch, a sustain pedal "retains"
137                    # the maximum velocity up to now due to
138                    # logarithmic nature of human loudness perception
139                    subpr = piano_roll[:, time_pedal_on:time_now]
140
141                    # Take the running maximum
142                    pedaled = np.maximum.accumulate(subpr, axis=1)
143                    piano_roll[:, time_pedal_on:time_now] = pedaled
144                    is_pedal_on = False
145
146        # Process pitch changes
147        # Need to sort the pitch bend list for the following to work
148        ordered_bends = sorted(self.pitch_bends, key=lambda bend: bend.time)
149        # Add in a bend of 0 at the end of time
150        end_bend = PitchBend(0, end_time)
151        for start_bend, end_bend in zip(ordered_bends,
152                                        ordered_bends[1:] + [end_bend]):
153            # Piano roll is already generated with everything bend = 0
154            if np.abs(start_bend.pitch) < 1:
155                continue
156            # Get integer and decimal part of bend amount
157            start_pitch = pitch_bend_to_semitones(start_bend.pitch)
158            bend_int = int(np.sign(start_pitch)*np.floor(np.abs(start_pitch)))
159            bend_decimal = np.abs(start_pitch - bend_int)
160            # Column indices effected by the bend
161            bend_range = np.r_[int(start_bend.time*fs):int(end_bend.time*fs)]
162            # Construct the bent part of the piano roll
163            bent_roll = np.zeros(piano_roll[:, bend_range].shape)
164            # Easiest to process differently depending on bend sign
165            if start_bend.pitch >= 0:
166                # First, pitch shift by the int amount
167                if bend_int is not 0:
168                    bent_roll[bend_int:] = piano_roll[:-bend_int, bend_range]
169                else:
170                    bent_roll = piano_roll[:, bend_range]
171                # Now, linear interpolate by the decimal place
172                bent_roll[1:] = ((1 - bend_decimal)*bent_roll[1:] +
173                                 bend_decimal*bent_roll[:-1])
174            else:
175                # Same procedure as for positive bends
176                if bend_int is not 0:
177                    bent_roll[:bend_int] = piano_roll[-bend_int:, bend_range]
178                else:
179                    bent_roll = piano_roll[:, bend_range]
180                bent_roll[:-1] = ((1 - bend_decimal)*bent_roll[:-1] +
181                                  bend_decimal*bent_roll[1:])
182            # Store bent portion back in piano roll
183            piano_roll[:, bend_range] = bent_roll
184
185        if times is None:
186            return piano_roll
187        piano_roll_integrated = np.zeros((128, times.shape[0]))
188        # Convert to column indices
189        times = np.array(np.round(times*fs), dtype=np.int)
190        for n, (start, end) in enumerate(zip(times[:-1], times[1:])):
191            if start < piano_roll.shape[1]:  # if start is >=, leave zeros
192                if start == end:
193                    end = start + 1
194                # Each column is the mean of the columns in piano_roll
195                piano_roll_integrated[:, n] = np.mean(piano_roll[:, start:end],
196                                                      axis=1)
197        return piano_roll_integrated
198
199    def get_chroma(self, fs=100, times=None, pedal_threshold=64):
200        """Get a sequence of chroma vectors from this instrument.
201
202        Parameters
203        ----------
204        fs : int
205            Sampling frequency of the columns, i.e. each column is spaced apart
206            by ``1./fs`` seconds.
207        times : np.ndarray
208            Times of the start of each column in the piano roll.
209            Default ``None`` which is ``np.arange(0, get_end_time(), 1./fs)``.
210        pedal_threshold : int
211            Value of control change 64 (sustain pedal) message that is less
212            than this value is reflected as pedal-off.  Pedals will be
213            reflected as elongation of notes in the piano roll.
214            If None, then CC64 message is ignored.
215            Default is 64.
216
217        Returns
218        -------
219        piano_roll : np.ndarray, shape=(12,times.shape[0])
220            Chromagram of this instrument.
221
222        """
223        # First, get the piano roll
224        piano_roll = self.get_piano_roll(fs=fs, times=times,
225                                         pedal_threshold=pedal_threshold)
226        # Fold into one octave
227        chroma_matrix = np.zeros((12, piano_roll.shape[1]))
228        for note in range(12):
229            chroma_matrix[note, :] = np.sum(piano_roll[note::12], axis=0)
230        return chroma_matrix
231
232    def get_end_time(self):
233        """Returns the time of the end of the events in this instrument.
234
235        Returns
236        -------
237        end_time : float
238            Time, in seconds, of the last event.
239
240        """
241        # Cycle through all note ends and all pitch bends and find the largest
242        events = ([n.end for n in self.notes] +
243                  [b.time for b in self.pitch_bends] +
244                  [c.time for c in self.control_changes])
245        # If there are no events, just return 0
246        if len(events) == 0:
247            return 0.
248        else:
249            return max(events)
250
251    def get_pitch_class_histogram(self, use_duration=False, use_velocity=False,
252                                  normalize=False):
253        """Computes the frequency of pitch classes of this instrument,
254        optionally weighted by their durations or velocities.
255
256        Parameters
257        ----------
258        use_duration : bool
259            Weight frequency by note duration.
260        use_velocity : bool
261            Weight frequency by note velocity.
262        normalize : bool
263            Normalizes the histogram such that the sum of bin values is 1.
264
265        Returns
266        -------
267        histogram : np.ndarray, shape=(12,)
268            Histogram of pitch classes given current instrument, optionally
269            weighted by their durations or velocities.
270        """
271
272        # Return all zeros if track is drum
273        if self.is_drum:
274            return np.zeros(12)
275
276        weights = np.ones(len(self.notes))
277
278        # Assumes that duration and velocity have equal weight
279        if use_duration:
280            weights *= [note.end - note.start for note in self.notes]
281        if use_velocity:
282            weights *= [note.velocity for note in self.notes]
283
284        histogram, _ = np.histogram([n.pitch % 12 for n in self.notes],
285                                    bins=np.arange(13),
286                                    weights=weights,
287                                    density=normalize)
288
289        return histogram
290
291    def get_pitch_class_transition_matrix(self, normalize=False,
292                                          time_thresh=0.05):
293        """Computes the pitch class transition matrix of this instrument.
294        Transitions are added whenever the end of a note is within
295        ``time_tresh`` from the start of any other note.
296
297        Parameters
298        ----------
299        normalize : bool
300            Normalize transition matrix such that matrix sum equals to 1.
301        time_thresh : float
302            Maximum temporal threshold, in seconds, between the start of a note
303            and end time of any other note for a transition to be added.
304
305        Returns
306        -------
307        transition_matrix : np.ndarray, shape=(12,12)
308            Pitch class transition matrix.
309        """
310
311        # instrument is drum or less than one note, return all zeros
312        if self.is_drum or len(self.notes) <= 1:
313            return np.zeros((12, 12))
314
315        # retrieve note starts, ends and pitch classes(nodes) from self.notes
316        starts, ends, nodes = np.array(
317            [[x.start, x.end, x.pitch % 12] for x in self.notes]).T
318
319        # compute distance matrix for all start and end time pairs
320        dist_mat = np.subtract.outer(ends, starts)
321
322        # find indices of pairs of notes where the end time of one note is
323        # within time_thresh of the start time of the other
324        sources, targets = np.where(abs(dist_mat) < time_thresh)
325
326        transition_matrix, _, _ = np.histogram2d(nodes[sources],
327                                                 nodes[targets],
328                                                 bins=np.arange(13),
329                                                 normed=normalize)
330        return transition_matrix
331
332    def remove_invalid_notes(self):
333        """Removes any notes whose end time is before or at their start time.
334
335        """
336        # Crete a list of all invalid notes
337        notes_to_delete = []
338        for note in self.notes:
339            if note.end <= note.start:
340                notes_to_delete.append(note)
341        # Remove the notes found
342        for note in notes_to_delete:
343            self.notes.remove(note)
344
345    def synthesize(self, fs=44100, wave=np.sin):
346        """Synthesize the instrument's notes using some waveshape.
347        For drum instruments, returns zeros.
348
349        Parameters
350        ----------
351        fs : int
352            Sampling rate of the synthesized audio signal.
353        wave : function
354            Function which returns a periodic waveform,
355            e.g. ``np.sin``, ``scipy.signal.square``, etc.
356
357        Returns
358        -------
359        synthesized : np.ndarray
360            Waveform of the instrument's notes, synthesized at ``fs``.
361
362        """
363        # Pre-allocate output waveform
364        synthesized = np.zeros(int(fs*(self.get_end_time() + 1)))
365
366        # If we're a percussion channel, just return the zeros
367        if self.is_drum:
368            return synthesized
369        # If the above if statement failed, we need to revert back to default
370        if not hasattr(wave, '__call__'):
371            raise ValueError('wave should be a callable Python function')
372        # This is a simple way to make the end of the notes fade-out without
373        # clicks
374        fade_out = np.linspace(1, 0, int(.1*fs))
375        # Create a frequency multiplier array for pitch bend
376        bend_multiplier = np.ones(synthesized.shape)
377        # Need to sort the pitch bend list for the loop below to work
378        ordered_bends = sorted(self.pitch_bends, key=lambda bend: bend.time)
379        # Add in a bend of 0 at the end of time
380        end_bend = PitchBend(0, self.get_end_time())
381        for start_bend, end_bend in zip(ordered_bends,
382                                        ordered_bends[1:] + [end_bend]):
383            # Bend start and end time in samples
384            start = int(start_bend.time*fs)
385            end = int(end_bend.time*fs)
386            # The multiplier will be (twelfth root of 2)^(bend semitones)
387            bend_semitones = pitch_bend_to_semitones(start_bend.pitch)
388            bend_amount = (2**(1/12.))**bend_semitones
389            # Sample indices effected by the bend
390            bend_multiplier[start:end] = bend_amount
391        # Add in waveform for each note
392        for note in self.notes:
393            # Indices in samples of this note
394            start = int(fs*note.start)
395            end = int(fs*note.end)
396            # Get frequency of note from MIDI note number
397            frequency = note_number_to_hz(note.pitch)
398            # When a pitch bend gets applied, there will be a sample
399            # discontinuity. So, we also need an array of offsets which get
400            # applied to compensate.
401            offsets = np.zeros(end - start)
402            for bend in ordered_bends:
403                bend_sample = int(bend.time*fs)
404                # Does this pitch bend fall within this note?
405                if bend_sample > start and bend_sample < end:
406                    # Compute the average bend so far
407                    bend_so_far = bend_multiplier[start:bend_sample].mean()
408                    bend_amount = bend_multiplier[bend_sample]
409                    # Compute the offset correction
410                    offset = (bend_so_far - bend_amount)*(bend_sample - start)
411                    # Store this offset for samples effected
412                    offsets[bend_sample - start:] = offset
413            # Compute the angular frequencies, bent, over this interval
414            frequencies = 2*np.pi*frequency*(bend_multiplier[start:end])/fs
415            # Synthesize using wave function at this frequency
416            note_waveform = wave(frequencies*np.arange(end - start) +
417                                 2*np.pi*frequency*offsets/fs)
418            # Apply an exponential envelope
419            envelope = np.exp(-np.arange(end - start)/(1.0*fs))
420            # Make the end of the envelope be a fadeout
421            if envelope.shape[0] > fade_out.shape[0]:
422                envelope[-fade_out.shape[0]:] *= fade_out
423            else:
424                envelope *= np.linspace(1, 0, envelope.shape[0])
425            # Multiply by velocity (don't think it's linearly scaled but
426            # whatever)
427            envelope *= note.velocity
428            # Add in envelope'd waveform to the synthesized signal
429            synthesized[start:end] += envelope*note_waveform
430
431        return synthesized
432
433    def fluidsynth(self, fs=44100, sf2_path=None):
434        """Synthesize using fluidsynth.
435
436        Parameters
437        ----------
438        fs : int
439            Sampling rate to synthesize.
440        sf2_path : str
441            Path to a .sf2 file.
442            Default ``None``, which uses the TimGM6mb.sf2 file included with
443            ``pretty_midi``.
444
445        Returns
446        -------
447        synthesized : np.ndarray
448            Waveform of the MIDI data, synthesized at ``fs``.
449
450        """
451        # If sf2_path is None, use the included TimGM6mb.sf2 path
452        if sf2_path is None:
453            sf2_path = pkg_resources.resource_filename(__name__, DEFAULT_SF2)
454
455        if not _HAS_FLUIDSYNTH:
456            raise ImportError("fluidsynth() was called but pyfluidsynth "
457                              "is not installed.")
458
459        if not os.path.exists(sf2_path):
460            raise ValueError("No soundfont file found at the supplied path "
461                             "{}".format(sf2_path))
462
463        # If the instrument has no notes, return an empty array
464        if len(self.notes) == 0:
465            return np.array([])
466
467        # Create fluidsynth instance
468        fl = fluidsynth.Synth(samplerate=fs)
469        # Load in the soundfont
470        sfid = fl.sfload(sf2_path)
471        # If this is a drum instrument, use channel 9 and bank 128
472        if self.is_drum:
473            channel = 9
474            # Try to use the supplied program number
475            res = fl.program_select(channel, sfid, 128, self.program)
476            # If the result is -1, there's no preset with this program number
477            if res == -1:
478                # So use preset 0
479                fl.program_select(channel, sfid, 128, 0)
480        # Otherwise just use channel 0
481        else:
482            channel = 0
483            fl.program_select(channel, sfid, 0, self.program)
484        # Collect all notes in one list
485        event_list = []
486        for note in self.notes:
487            event_list += [[note.start, 'note on', note.pitch, note.velocity]]
488            event_list += [[note.end, 'note off', note.pitch]]
489        for bend in self.pitch_bends:
490            event_list += [[bend.time, 'pitch bend', bend.pitch]]
491        for control_change in self.control_changes:
492            event_list += [[control_change.time, 'control change',
493                            control_change.number, control_change.value]]
494        # Sort the event list by time, and secondarily by whether the event
495        # is a note off
496        event_list.sort(key=lambda x: (x[0], x[1] != 'note off'))
497        # Add some silence at the beginning according to the time of the first
498        # event
499        current_time = event_list[0][0]
500        # Convert absolute seconds to relative samples
501        next_event_times = [e[0] for e in event_list[1:]]
502        for event, end in zip(event_list[:-1], next_event_times):
503            event[0] = end - event[0]
504        # Include 1 second of silence at the end
505        event_list[-1][0] = 1.
506        # Pre-allocate output array
507        total_time = current_time + np.sum([e[0] for e in event_list])
508        synthesized = np.zeros(int(np.ceil(fs*total_time)))
509        # Iterate over all events
510        for event in event_list:
511            # Process events based on type
512            if event[1] == 'note on':
513                fl.noteon(channel, event[2], event[3])
514            elif event[1] == 'note off':
515                fl.noteoff(channel, event[2])
516            elif event[1] == 'pitch bend':
517                fl.pitch_bend(channel, event[2])
518            elif event[1] == 'control change':
519                fl.cc(channel, event[2], event[3])
520            # Add in these samples
521            current_sample = int(fs*current_time)
522            end = int(fs*(current_time + event[0]))
523            samples = fl.get_samples(end - current_sample)[::2]
524            synthesized[current_sample:end] += samples
525            # Increment the current sample
526            current_time += event[0]
527        # Close fluidsynth
528        fl.delete()
529
530        return synthesized
531
532    def __repr__(self):
533        return 'Instrument(program={}, is_drum={}, name="{}")'.format(
534            self.program, self.is_drum, self.name.replace('"', r'\"'))
535