1"""Stuff to parse WAVE files.
2
3Usage.
4
5Reading WAVE files:
6      f = wave.open(file, 'r')
7where file is either the name of a file or an open file pointer.
8The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.
11
12This returns an instance of a class with the following public methods:
13      getnchannels()  -- returns number of audio channels (1 for
14                         mono, 2 for stereo)
15      getsampwidth()  -- returns sample width in bytes
16      getframerate()  -- returns sampling frequency
17      getnframes()    -- returns number of audio frames
18      getcomptype()   -- returns compression type ('NONE' for linear samples)
      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear samples)
21      getparams()     -- returns a namedtuple consisting of all of the
22                         above in the above order
23      getmarkers()    -- returns None (for compatibility with the
24                         aifc module)
25      getmark(id)     -- raises an error since the mark does not
26                         exist (for compatibility with the aifc module)
27      readframes(n)   -- returns at most n frames of audio
28      rewind()        -- rewind to the beginning of the audio stream
29      setpos(pos)     -- seek to the specified position
30      tell()          -- return the current position
31      close()         -- close the instance (make it unusable)
32The position returned by tell() and the position given to setpos()
33are compatible and have nothing to do with the actual position in the
34file.
35The close() method is called automatically when the class instance
36is destroyed.
37
38Writing WAVE files:
39      f = wave.open(file, 'w')
40where file is either the name of a file or an open file pointer.
41The open file pointer must have methods write(), tell(), seek(), and
42close().
43
44This returns an instance of a class with the following public methods:
45      setnchannels(n) -- set the number of channels
46      setsampwidth(n) -- set the sample width
47      setframerate(n) -- set the frame rate
48      setnframes(n)   -- set the number of frames
49      setcomptype(type, name)
50                      -- set the compression type and the
51                         human-readable compression type
52      setparams(tuple)
53                      -- set all parameters at once
54      tell()          -- return current position in output file
55      writeframesraw(data)
56                      -- write audio frames without patching up the
57                         file header
58      writeframes(data)
59                      -- write audio frames and patch up the file header
60      close()         -- patch up the file header and close the
61                         output file
62You should set the parameters before the first writeframesraw or
63writeframes.  The total number of frames does not need to be set,
64but when it is set to the correct value, the header does not have to
65be patched up.
It is best to first set all parameters, including the compression
type if desired, and then write audio frames using writeframesraw.
68When all frames have been written, either call writeframes(b'') or
69close() to patch up the sizes in the header.
70The close() method is called automatically when the class instance
71is destroyed.
72"""
73
74from chunk import Chunk
75from collections import namedtuple
76import audioop
77import builtins
78import struct
79import sys
80
81
# Names exported by a star-import of this module.
__all__ = ["open", "Error", "Wave_read", "Wave_write"]
83
class Error(Exception):
    """Exception raised by this module for invalid files, parameters or modes."""
86
# Format tag stored in the 'fmt ' chunk; the only value the reader accepts
# and the value the writer always emits.
WAVE_FORMAT_PCM = 0x0001

# NOTE(review): presumably array-module type codes indexed by sample width
# in bytes (None where no code applies); not referenced elsewhere in the
# code visible here -- confirm before relying on it.
_array_fmts = (None, 'b', 'h', None, 'i')

# Record type returned by the getparams() methods of the reader and writer.
_wave_params = namedtuple(
    '_wave_params',
    ['nchannels', 'sampwidth', 'framerate', 'nframes', 'comptype', 'compname'],
)
93
class Wave_read:
    """Reader for uncompressed (PCM) WAVE files.

    Variables used in this class:

    These variables are available to the user through appropriate
    methods of this class:
    _file -- the outer RIFF chunk wrapping the file passed to initfp()
              available through the getfp() method; None after close()
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE')
              available through the getcomptype() method
    _compname -- the human-readable compression type
              (always 'not compressed')
              available through the getcompname() method
    _soundpos -- the position in the audio stream, in frames
              available through the tell() method, set through the
              setpos() method

    These variables are used internally only:
    _fmt_chunk_read -- 1 iff the 'fmt ' chunk has been read
    _data_seek_needed -- 1 iff the data chunk must be repositioned to
              _soundpos before the next readframes()
    _data_chunk -- the chunk object for the 'data' chunk
    _framesize -- size of one frame in the file, in bytes
              (_nchannels * _sampwidth)
    _convert -- optional callable applied to the bytes read by
              readframes(); None by default
    _i_opened_the_file -- the file object opened by __init__() when a
              file name was given, otherwise None
    """

    def initfp(self, file):
        """Parse the RIFF/WAVE headers of *file* and locate the data chunk.

        Raises Error when the file is not a RIFF/WAVE file, when the
        'data' chunk precedes the 'fmt ' chunk, or when either chunk is
        missing.
        """
        self._convert = None
        self._soundpos = 0
        riff = Chunk(file, bigendian=0)
        self._file = riff
        if riff.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if riff.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while True:
            self._data_seek_needed = 1
            try:
                subchunk = Chunk(riff, bigendian=0)
            except EOFError:
                # No more sub-chunks in the RIFF container.
                break
            name = subchunk.getname()
            if name == b'fmt ':
                self._read_fmt_chunk(subchunk)
                self._fmt_chunk_read = 1
            elif name == b'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = subchunk
                self._nframes = subchunk.chunksize // self._framesize
                # Already positioned at the first frame; stop scanning here.
                self._data_seek_needed = 0
                break
            subchunk.skip()
        if not (self._fmt_chunk_read and self._data_chunk):
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open *f* for reading; *f* is a file name or an open file object."""
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'rb')
            self._i_opened_the_file = f
        # Otherwise assume f is an already-open file object.
        try:
            self.initfp(f)
        except BaseException:
            # Only close what this constructor opened, then re-raise.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Return the internal RIFF chunk object (None after close())."""
        return self._file

    def rewind(self):
        """Move the read position back to the first frame."""
        self._data_seek_needed = 1
        self._soundpos = 0

    def close(self):
        """Drop the chunk reference and close the file if this object opened it."""
        self._file = None
        owned = self._i_opened_the_file
        if not owned:
            return
        self._i_opened_the_file = None
        owned.close()

    def tell(self):
        """Return the current position in frames."""
        return self._soundpos

    def getnchannels(self):
        """Return the number of audio channels."""
        return self._nchannels

    def getnframes(self):
        """Return the number of audio frames in the data chunk."""
        return self._nframes

    def getsampwidth(self):
        """Return the sample width in bytes."""
        return self._sampwidth

    def getframerate(self):
        """Return the sampling frequency."""
        return self._framerate

    def getcomptype(self):
        """Return the compression type."""
        return self._comptype

    def getcompname(self):
        """Return the human-readable compression type."""
        return self._compname

    def getparams(self):
        """Return all parameters as a _wave_params namedtuple."""
        return _wave_params(
            nchannels=self.getnchannels(),
            sampwidth=self.getsampwidth(),
            framerate=self.getframerate(),
            nframes=self.getnframes(),
            comptype=self.getcomptype(),
            compname=self.getcompname(),
        )

    def getmarkers(self):
        """Return None; markers are not supported."""
        return None

    def getmark(self, id):
        """Always raise Error; markers are not supported."""
        raise Error('no marks')

    def setpos(self, pos):
        """Set the read position to frame *pos*; raise Error if out of range."""
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = 1

    def readframes(self, nframes):
        """Read and return at most *nframes* frames of audio as bytes."""
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            offset = self._soundpos * self._framesize
            if offset:
                self._data_chunk.seek(offset, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return b''
        frames = self._data_chunk.read(nframes * self._framesize)
        # Multi-byte samples are byte-swapped when running on a
        # big-endian host.
        swap_needed = self._sampwidth != 1 and sys.byteorder == 'big'
        if swap_needed:
            frames = audioop.byteswap(frames, self._sampwidth)
        if self._convert and frames:
            frames = self._convert(frames)
        frames_read = len(frames) // (self._nchannels * self._sampwidth)
        self._soundpos = self._soundpos + frames_read
        return frames

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Decode the 'fmt ' chunk and set the audio parameters.

        Raises EOFError when the chunk is too short, and Error for a
        format tag other than WAVE_FORMAT_PCM, a zero sample width or a
        zero channel count.
        """
        try:
            header = struct.unpack_from('<HHLLH', chunk.read(14))
        except struct.error:
            raise EOFError from None
        (format_tag, self._nchannels, self._framerate,
         _avg_bytes_per_sec, _block_align) = header
        if format_tag != WAVE_FORMAT_PCM:
            raise Error('unknown format: %r' % (format_tag,))
        try:
            bits_per_sample = struct.unpack_from('<H', chunk.read(2))[0]
        except struct.error:
            raise EOFError from None
        # Round the bit depth up to a whole number of bytes.
        self._sampwidth = (bits_per_sample + 7) // 8
        if not self._sampwidth:
            raise Error('bad sample width')
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'
274
class Wave_write:
    """Writer for uncompressed (PCM) WAVE files.

    Variables used in this class:

    These variables are user settable through appropriate methods
    of this class:
    _file -- the open file with methods write(), close(), tell(), seek()
              set through the __init__() method; flush() is also
              called on it by close()
    _comptype -- the compression type (only 'NONE' is accepted)
              set through the setcomptype() or setparams() method;
              defaults to 'NONE'
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method;
              defaults to 'not compressed'
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- True once the file header has been written
    _convert -- optional callable applied to frame data before it is
              written; None by default
    _i_opened_the_file -- the file object opened by __init__() when a
              file name was given, otherwise None
    """

    def __init__(self, f):
        """Open *f* for writing; *f* is a file name or an open file object."""
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except BaseException:
            # Only close what this constructor opened, then re-raise.
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        """Attach *file* and reset all parameters and counters."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # Only uncompressed data is supported, so start from the same
        # values the reader reports; this keeps getcomptype(),
        # getcompname() and getparams() usable without setcomptype().
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        """Set the number of channels; raise Error if < 1 or data was written."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        """Return the number of channels; raise Error if not set."""
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        """Set the sample width in bytes (1 to 4); raise Error otherwise."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        """Return the sample width in bytes; raise Error if not set."""
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the frame rate, rounded to the nearest integer.

        Raises Error when data has already been written or when the
        rate is not positive after rounding.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if framerate <= 0:
            raise Error('bad frame rate')
        rounded_framerate = int(round(framerate))
        # A small positive rate such as 0.4 rounds to 0, which would be
        # stored as "not set"; reject it here instead.
        if rounded_framerate <= 0:
            raise Error('bad frame rate')
        self._framerate = rounded_framerate

    def getframerate(self):
        """Return the frame rate; raise Error if not set."""
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        """Set the frame count to record in the header."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        """Set the compression type and name; only 'NONE' is accepted."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        """Return the compression type."""
        return self._comptype

    def getcompname(self):
        """Return the human-readable compression type."""
        return self._compname

    def setparams(self, params):
        """Set all parameters at once from a 6-item sequence.

        The items are (nchannels, sampwidth, framerate, nframes,
        comptype, compname).
        """
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return all parameters as a _wave_params namedtuple.

        Raises Error unless the channel count, sample width and frame
        rate have all been set.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return _wave_params(self._nchannels, self._sampwidth, self._framerate,
              self._nframes, self._comptype, self._compname)

    def setmark(self, id, pos, name):
        """Always raise Error; markers are not supported."""
        raise Error('setmark() not supported')

    def getmark(self, id):
        """Always raise Error; markers are not supported."""
        raise Error('no marks')

    def getmarkers(self):
        """Return None; markers are not supported."""
        return None

    def tell(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching up the file header."""
        if not isinstance(data, (bytes, bytearray)):
            # Treat any other buffer object as a flat sequence of bytes.
            data = memoryview(data).cast('B')
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        # Multi-byte samples are byte-swapped when running on a
        # big-endian host.
        if self._sampwidth != 1 and sys.byteorder == 'big':
            data = audioop.byteswap(data, self._sampwidth)
        self._file.write(data)
        self._datawritten += len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write audio frames, then patch the header if its sizes are stale."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Finish the header, flush, and close the file if this object opened it."""
        try:
            # Test identity rather than truthiness so that a file-like
            # object that happens to be falsy is still finalized.
            if self._file is not None:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header on first use; raise Error if a parameter is unset."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, 'fmt ' and 'data' headers.

        When no frame count was set, it is derived from *initlength*
        (a byte count).  The positions of the two size fields are
        remembered for _patchheader() when the file supports tell().
        """
        assert not self._headerwritten
        self._file.write(b'RIFF')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            # The file cannot report its position, so the size fields
            # cannot be located for patching later.
            self._form_length_pos = None
        # Fields: RIFF size, 'WAVE', 'fmt ', fmt chunk size (16), format
        # tag, channels, frame rate, bytes per second, bytes per frame,
        # bits per sample, 'data'.
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, b'WAVE', b'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, b'data'))
        if self._form_length_pos is not None:
            self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite both size fields in the header to match the data written.

        Does nothing when the header is already correct; otherwise the
        file must support tell() and seek().
        """
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
501
def open(f, mode=None):
    """Return a Wave_read or Wave_write object for *f*.

    *f* is a file name or an open file object.  When *mode* is None it
    is taken from f.mode if that attribute exists, otherwise 'rb'.
    Modes 'r' and 'rb' return a Wave_read; 'w' and 'wb' return a
    Wave_write; any other mode raises Error.
    """
    if mode is None:
        mode = getattr(f, 'mode', 'rb')
    if mode in ('r', 'rb'):
        return Wave_read(f)
    if mode in ('w', 'wb'):
        return Wave_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
514