1"""Stuff to parse WAVE files.
2
3Usage.
4
5Reading WAVE files:
6      f = wave.open(file, 'r')
7where file is either the name of a file or an open file pointer.
8The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.
11
12This returns an instance of a class with the following public methods:
13      getnchannels()  -- returns number of audio channels (1 for
14                         mono, 2 for stereo)
15      getsampwidth()  -- returns sample width in bytes
16      getframerate()  -- returns sampling frequency
17      getnframes()    -- returns number of audio frames
18      getcomptype()   -- returns compression type ('NONE' for linear samples)
      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear samples)
21      getparams()     -- returns a namedtuple consisting of all of the
22                         above in the above order
23      getmarkers()    -- returns None (for compatibility with the
24                         aifc module)
25      getmark(id)     -- raises an error since the mark does not
26                         exist (for compatibility with the aifc module)
27      readframes(n)   -- returns at most n frames of audio
28      rewind()        -- rewind to the beginning of the audio stream
29      setpos(pos)     -- seek to the specified position
30      tell()          -- return the current position
31      close()         -- close the instance (make it unusable)
32The position returned by tell() and the position given to setpos()
33are compatible and have nothing to do with the actual position in the
34file.
35The close() method is called automatically when the class instance
36is destroyed.
37
38Writing WAVE files:
39      f = wave.open(file, 'w')
40where file is either the name of a file or an open file pointer.
41The open file pointer must have methods write(), tell(), seek(), and
42close().
43
44This returns an instance of a class with the following public methods:
45      setnchannels(n) -- set the number of channels
46      setsampwidth(n) -- set the sample width
47      setframerate(n) -- set the frame rate
48      setnframes(n)   -- set the number of frames
49      setcomptype(type, name)
50                      -- set the compression type and the
51                         human-readable compression type
52      setparams(tuple)
53                      -- set all parameters at once
54      tell()          -- return current position in output file
55      writeframesraw(data)
56                      -- write audio frames without patching up the
57                         file header
58      writeframes(data)
59                      -- write audio frames and patch up the file header
60      close()         -- patch up the file header and close the
61                         output file
62You should set the parameters before the first writeframesraw or
63writeframes.  The total number of frames does not need to be set,
64but when it is set to the correct value, the header does not have to
65be patched up.
It is best to first set all parameters, possibly including the
compression type, and then write audio frames using writeframesraw.
68When all frames have been written, either call writeframes(b'') or
69close() to patch up the sizes in the header.
70The close() method is called automatically when the class instance
71is destroyed.
72"""
73
74import builtins
75
# Names exported by ``from wave import *``: the two opener functions, the
# module's exception class, and the reader and writer classes.
__all__ = ["open", "openfp", "Error", "Wave_read", "Wave_write"]
77
class Error(Exception):
    """Exception raised by this module for malformed WAVE data or for
    invalid use of the reader and writer objects."""
80
# Format tag for uncompressed PCM audio.  The reader rejects any fmt chunk
# whose format tag differs from this value, and the writer always stores it
# in the header it emits.
WAVE_FORMAT_PCM = 0x0001

# NOTE(review): presumably array typecodes indexed by sample width in bytes
# (None where no typecode applies).  Nothing in the visible part of this
# module references it -- confirm before relying on it.
_array_fmts = None, 'b', 'h', None, 'i'
84
85import audioop
86import struct
87import sys
88from chunk import Chunk
89from collections import namedtuple
90import warnings
91
# Immutable record returned by the getparams() methods.  The field order is
# the same order in which Wave_write.setparams() unpacks its argument.
_wave_params = namedtuple('_wave_params',
                     'nchannels sampwidth framerate nframes comptype compname')
94
class Wave_read:
    """Reader for uncompressed (PCM) WAVE data.

    State exposed to the user through accessor methods:
    _file -- the RIFF chunk wrapping the open file (which needs read(),
              close() and seek()); created by initfp(), returned by getfp()
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE')
              available through the getcomptype() method
    _compname -- the human-readable compression type
              (always 'not compressed')
              available through the getcompname() method
    _soundpos -- the position in the audio stream, counted in frames
              available through the tell() method, set through the
              setpos() and rewind() methods

    State used internally only:
    _fmt_chunk_read -- 1 once the fmt chunk has been read
    _data_seek_needed -- 1 when the data chunk must be repositioned
              before readframes() can read from it
    _data_chunk -- the Chunk instance for the data chunk
    _framesize -- size in bytes of one frame (one sample per channel)
    _convert -- optional callable applied to the bytes that were read;
              None unless something else assigns it
    _i_opened_the_file -- the file object __init__() opened itself when
              it was given a file name, otherwise None
    """

    def initfp(self, file):
        """Parse the RIFF/WAVE structure of *file* up to the data chunk.

        Chunks are scanned in file order; unknown chunks are skipped.
        Scanning stops at the first data chunk, which must come after
        the fmt chunk.  Raises Error when the file is not a RIFF/WAVE
        file or when either the fmt or the data chunk is missing.
        """
        self._convert = None
        self._soundpos = 0
        self._file = Chunk(file, bigendian=0)
        if self._file.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while True:
            self._data_seek_needed = 1
            try:
                subchunk = Chunk(self._file, bigendian=0)
            except EOFError:
                # No more chunks in the RIFF container.
                break
            name = subchunk.getname()
            if name == b'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = subchunk
                self._nframes = subchunk.chunksize // self._framesize
                # The chunk was just opened, so it is already positioned
                # at its first frame.
                self._data_seek_needed = 0
                break
            if name == b'fmt ':
                self._read_fmt_chunk(subchunk)
                self._fmt_chunk_read = 1
            subchunk.skip()
        if not (self._fmt_chunk_read and self._data_chunk):
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open *f* for reading.

        *f* is either a file name (str), which is opened in binary mode
        and closed again by close(), or an already open file object,
        which is left open by close().
        """
        self._i_opened_the_file = None
        stream = f
        if isinstance(f, str):
            stream = builtins.open(f, 'rb')
            self._i_opened_the_file = stream
        # else, assume it is an open file object already
        try:
            self.initfp(stream)
        except BaseException:
            # Only close a file this constructor opened itself.
            if self._i_opened_the_file:
                stream.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Return the RIFF chunk object being read (None after close())."""
        return self._file

    def rewind(self):
        """Move the read position back to the first frame."""
        self._soundpos = 0
        self._data_seek_needed = 1

    def close(self):
        """Drop the reference to the file; close it if __init__ opened it."""
        self._file = None
        owned = self._i_opened_the_file
        if not owned:
            return
        self._i_opened_the_file = None
        owned.close()

    def tell(self):
        """Return the current read position, in frames."""
        return self._soundpos

    def getnchannels(self):
        """Return the number of audio channels."""
        return self._nchannels

    def getnframes(self):
        """Return the number of audio frames in the data chunk."""
        return self._nframes

    def getsampwidth(self):
        """Return the sample width in bytes."""
        return self._sampwidth

    def getframerate(self):
        """Return the sampling frequency."""
        return self._framerate

    def getcomptype(self):
        """Return the compression type ('NONE')."""
        return self._comptype

    def getcompname(self):
        """Return the human-readable compression type ('not compressed')."""
        return self._compname

    def getparams(self):
        """Return all parameters as a _wave_params namedtuple."""
        return _wave_params(
            nchannels=self.getnchannels(),
            sampwidth=self.getsampwidth(),
            framerate=self.getframerate(),
            nframes=self.getnframes(),
            comptype=self.getcomptype(),
            compname=self.getcompname(),
        )

    def getmarkers(self):
        """Return None; present for compatibility with the aifc module."""
        return None

    def getmark(self, id):
        """Always raise Error; present for compatibility with the aifc module."""
        raise Error('no marks')

    def setpos(self, pos):
        """Set the read position to frame *pos*.

        Raises Error when *pos* is negative or beyond the last frame.
        The underlying seek is deferred until the next readframes().
        """
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._data_seek_needed = 1
        self._soundpos = pos

    def readframes(self, nframes):
        """Read and return at most *nframes* frames of audio as bytes.

        Fewer frames are returned when the data chunk runs out.  The
        read position advances by the number of frames returned.
        """
        if self._data_seek_needed:
            source = self._data_chunk
            source.seek(0, 0)
            offset = self._soundpos * self._framesize
            if offset:
                source.seek(offset, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return b''
        frames = self._data_chunk.read(nframes * self._framesize)
        # Samples are stored little-endian; swap multi-byte samples on
        # big-endian hosts.
        if sys.byteorder == 'big' and self._sampwidth != 1:
            frames = audioop.byteswap(frames, self._sampwidth)
        if self._convert and frames:
            frames = self._convert(frames)
        self._soundpos += len(frames) // (self._nchannels * self._sampwidth)
        return frames

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Decode the fmt chunk and set the audio parameters from it.

        Raises EOFError when the chunk is too short, and Error when the
        format is not PCM, the sample width is zero or the channel count
        is zero.
        """
        try:
            header = struct.unpack_from('<HHLLH', chunk.read(14))
        except struct.error:
            raise EOFError from None
        # The average byte rate and block alignment are read but not used.
        (format_tag, self._nchannels, self._framerate,
         _avg_bytes_per_sec, _block_align) = header
        if format_tag != WAVE_FORMAT_PCM:
            raise Error('unknown format: %r' % (format_tag,))
        try:
            bits_per_sample = struct.unpack_from('<H', chunk.read(2))[0]
        except struct.error:
            raise EOFError from None
        # Round the bit depth up to whole bytes.
        self._sampwidth = (bits_per_sample + 7) // 8
        if not self._sampwidth:
            raise Error('bad sample width')
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'
275
class Wave_write:
    """Writer for uncompressed (PCM) WAVE data.

    These variables are user settable through appropriate methods
    of this class:
    _file -- the open file with methods write(), close(), tell(), seek()
              set through the __init__() method
    _comptype -- the compression type (only 'NONE' is supported)
              set through the setcomptype() or setparams() method
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- True once the header has been written
    _form_length_pos -- file offset of the RIFF length field, or None
              when the file cannot report its position
    _data_length_pos -- file offset of the data chunk length field
              (only set when _form_length_pos is not None)
    _convert -- optional callable applied to data before it is written;
              None unless something else assigns it
    _i_opened_the_file -- the file object __init__() opened itself when
              it was given a file name, otherwise None
    """

    # Class-level defaults keep close() -- and therefore __del__() -- safe
    # on an instance whose __init__() failed before initfp() ran, for
    # example when builtins.open() raises for an unwritable path.
    _file = None
    _i_opened_the_file = None

    def __init__(self, f):
        """Open *f* for writing.

        *f* is either a file name (str), which is opened in binary mode
        and closed again by close(), or an already open file object,
        which is flushed but left open by close().
        """
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except BaseException:
            # Only close a file this constructor opened itself.
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        """Reset all writer state and attach *file* as the output."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # Only PCM is supported, so start with the values setcomptype()
        # would accept; this lets getcomptype(), getcompname() and
        # getparams() work without an explicit setcomptype() call.
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        """Set the number of channels; must be at least 1.

        Raises Error once audio data has been written.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        """Return the number of channels; raises Error if not yet set."""
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        """Set the sample width in bytes; must be between 1 and 4.

        Raises Error once audio data has been written.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        """Return the sample width in bytes; raises Error if not yet set."""
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the frame rate, rounded to the nearest integer.

        Raises Error once audio data has been written, or when the
        rounded frame rate is not positive.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        # Validate the rounded value: a small positive input such as 0.4
        # would otherwise pass the check and store a frame rate of 0.
        rounded_framerate = int(round(framerate))
        if rounded_framerate <= 0:
            raise Error('bad frame rate')
        self._framerate = rounded_framerate

    def getframerate(self):
        """Return the frame rate; raises Error if not yet set."""
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        """Set the number of frames announced in the header.

        Raises Error once audio data has been written.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        """Set the compression type and its human-readable name.

        Only 'NONE' is accepted as *comptype*.  Raises Error for any
        other type, or once audio data has been written.
        """
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        """Return the compression type."""
        return self._comptype

    def getcompname(self):
        """Return the human-readable compression type."""
        return self._compname

    def setparams(self, params):
        """Set all parameters at once.

        *params* is a 6-item sequence (nchannels, sampwidth, framerate,
        nframes, comptype, compname).  Raises Error once audio data has
        been written, or when an individual setter rejects its value.
        """
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return all parameters as a _wave_params namedtuple.

        Raises Error unless the channel count, sample width and frame
        rate have all been set.  The nframes field is the value set for
        the header, not the number of frames written so far.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return _wave_params(self._nchannels, self._sampwidth, self._framerate,
              self._nframes, self._comptype, self._compname)

    def setmark(self, id, pos, name):
        """Always raise Error; present for compatibility with the aifc module."""
        raise Error('setmark() not supported')

    def getmark(self, id):
        """Always raise Error; present for compatibility with the aifc module."""
        raise Error('no marks')

    def getmarkers(self):
        """Return None; present for compatibility with the aifc module."""
        return None

    def tell(self):
        """Return the current position, in frames written."""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching up the file header.

        *data* is a bytes-like object.  The header is written first if
        it has not been written yet, which raises Error when a required
        parameter is still unset.
        """
        if not isinstance(data, (bytes, bytearray)):
            # View any other buffer as plain bytes so len() counts bytes.
            data = memoryview(data).cast('B')
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        # Samples are stored little-endian; swap multi-byte samples on
        # big-endian hosts.
        if self._sampwidth != 1 and sys.byteorder == 'big':
            data = audioop.byteswap(data, self._sampwidth)
        self._file.write(data)
        self._datawritten += len(data)
        self._nframeswritten += nframes

    def writeframes(self, data):
        """Write audio frames and make sure the header sizes are correct."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Patch up the header, flush the output and release the file.

        The file is closed only when __init__() opened it.  Calling
        close() again afterwards does nothing.
        """
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            # Release the file even when finishing the header failed.
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once, checking the required parameters first.

        *datasize* is the size in bytes of the data about to be written;
        it is used for the header sizes when no frame count was set.
        Raises Error when a required parameter is missing.
        """
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, fmt and data chunk headers.

        The offsets of the two length fields are remembered so that
        _patchheader() can rewrite them later.
        """
        assert not self._headerwritten
        self._file.write(b'RIFF')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            # The file cannot report its position.
            self._form_length_pos = None
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, b'WAVE', b'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, b'data'))
        if self._form_length_pos is not None:
            self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite the header length fields to match the data written.

        Uses tell() and seek() on the output file and restores the
        write position afterwards.
        """
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
502
def open(f, mode=None):
    """Open a WAVE file for reading or writing.

    *f* is a file name or an open file object.  *mode* is 'r' or 'rb'
    for reading and 'w' or 'wb' for writing; when omitted it is taken
    from f.mode if that attribute exists, and is 'rb' otherwise.
    Returns a Wave_read or Wave_write instance; raises Error for any
    other mode.
    """
    if mode is None:
        mode = f.mode if hasattr(f, 'mode') else 'rb'
    if mode in ('r', 'rb'):
        return Wave_read(f)
    if mode in ('w', 'wb'):
        return Wave_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
515
def openfp(f, mode=None):
    """Deprecated alias of open(); warns, then forwards both arguments."""
    message = ("wave.openfp is deprecated since Python 3.7. "
               "Use wave.open instead.")
    # stacklevel=2 attributes the warning to the caller of openfp().
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return open(f, mode=mode)
520