1"""Stuff to parse WAVE files.
2
3Usage.
4
5Reading WAVE files:
6      f = wave.open(file, 'r')
7where file is either the name of a file or an open file pointer.
8The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.
11
12This returns an instance of a class with the following public methods:
13      getnchannels()  -- returns number of audio channels (1 for
14                         mono, 2 for stereo)
15      getsampwidth()  -- returns sample width in bytes
16      getframerate()  -- returns sampling frequency
17      getnframes()    -- returns number of audio frames
18      getcomptype()   -- returns compression type ('NONE' for linear samples)
      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear samples)
21      getparams()     -- returns a tuple consisting of all of the
22                         above in the above order
23      getmarkers()    -- returns None (for compatibility with the
24                         aifc module)
25      getmark(id)     -- raises an error since the mark does not
26                         exist (for compatibility with the aifc module)
27      readframes(n)   -- returns at most n frames of audio
28      rewind()        -- rewind to the beginning of the audio stream
29      setpos(pos)     -- seek to the specified position
30      tell()          -- return the current position
31      close()         -- close the instance (make it unusable)
32The position returned by tell() and the position given to setpos()
33are compatible and have nothing to do with the actual position in the
34file.
35The close() method is called automatically when the class instance
36is destroyed.
37
38Writing WAVE files:
39      f = wave.open(file, 'w')
40where file is either the name of a file or an open file pointer.
The open file pointer must have methods write(), tell(), seek(),
flush(), and close().
43
44This returns an instance of a class with the following public methods:
45      setnchannels(n) -- set the number of channels
46      setsampwidth(n) -- set the sample width
47      setframerate(n) -- set the frame rate
48      setnframes(n)   -- set the number of frames
49      setcomptype(type, name)
50                      -- set the compression type and the
51                         human-readable compression type
52      setparams(tuple)
53                      -- set all parameters at once
      tell()          -- return the number of frames written so far
      writeframesraw(data)
                      -- write audio frames without patching up the
                         file header
58      writeframes(data)
59                      -- write audio frames and patch up the file header
60      close()         -- patch up the file header and close the
61                         output file
62You should set the parameters before the first writeframesraw or
63writeframes.  The total number of frames does not need to be set,
64but when it is set to the correct value, the header does not have to
65be patched up.
It is best to first set all parameters, optionally including the
compression type, and then write audio frames using writeframesraw.
68When all frames have been written, either call writeframes('') or
69close() to patch up the sizes in the header.
70The close() method is called automatically when the class instance
71is destroyed.
72"""
73
74import __builtin__
75
# Public names exported by "import *" from this module.
__all__ = ["open", "openfp", "Error"]
77
class Error(Exception):
    """Raised for malformed WAVE data and for misuse of the reader/writer API."""
80
81WAVE_FORMAT_PCM = 0x0001
82
83_array_fmts = None, 'b', 'h', None, 'l'
84
85# Determine endian-ness
86import struct
87if struct.pack("h", 1) == "\000\001":
88    big_endian = 1
89else:
90    big_endian = 0
91
92from chunk import Chunk
93
class Wave_read:
    """Reader for WAVE files holding uncompressed (PCM) audio.

    These variables are available to the user through appropriate
    methods of this class:
    _file -- a Chunk wrapping the RIFF container of the open file
              (the file needs methods read(), close(), and seek());
              set through initfp() and returned by getfp()
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE': only PCM is read)
              available through the getcomptype() method
    _compname -- the human-readable compression type
              available through the getcompname() method
    _soundpos -- the position in the audio stream
              available through the tell() method, set through the
              setpos() method

    These variables are used internally only:
    _fmt_chunk_read -- 1 iff the FMT chunk has been read
    _data_seek_needed -- 1 iff the DATA chunk must be repositioned to
              _soundpos before the next readframes()
    _data_chunk -- instantiation of a chunk class for the DATA chunk
    _framesize -- size of one frame in the file
    _convert -- optional callable applied to data returned by
              readframes(); never set by this class (always None)
    _i_opened_the_file -- the file object opened by __init__ when it was
              given a filename, else None; closed by close()
    """

    def initfp(self, file):
        """Parse the RIFF/WAVE header of *file* up to the start of the audio.

        Raises Error if the file is not a RIFF/WAVE file, uses an
        unsupported format, or lacks a 'fmt ' or 'data' chunk.
        """
        self._convert = None
        self._soundpos = 0
        self._file = Chunk(file, bigendian = 0)
        if self._file.getname() != 'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != 'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while True:
            self._data_seek_needed = 1
            try:
                chunk = Chunk(self._file, bigendian = 0)
            except EOFError:
                # No further sub-chunks in the RIFF container.
                break
            chunkname = chunk.getname()
            if chunkname == 'fmt ':
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = 1
            elif chunkname == 'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = chunk
                # _framesize is non-zero: _read_fmt_chunk() rejects
                # zero channels and zero sample widths.
                self._nframes = chunk.chunksize // self._framesize
                # Stop here, positioned at the start of the audio data.
                self._data_seek_needed = 0
                break
            chunk.skip()
        if not self._fmt_chunk_read or self._data_chunk is None:
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open *f* (a filename or an open file object) for reading."""
        self._i_opened_the_file = None
        if isinstance(f, basestring):
            f = __builtin__.open(f, 'rb')
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except:
            # Don't leak a file we opened ourselves; re-raise whatever
            # went wrong (including KeyboardInterrupt) unchanged.
            if self._i_opened_the_file:
                f.close()
                self._i_opened_the_file = None
            raise

    def __del__(self):
        self.close()
    #
    # User visible methods.
    #
    def getfp(self):
        """Return the Chunk object wrapping the RIFF container."""
        return self._file

    def rewind(self):
        """Reposition to the first audio frame (seek happens lazily)."""
        self._data_seek_needed = 1
        self._soundpos = 0

    def close(self):
        """Close the file if this instance opened it; make it unusable."""
        if self._i_opened_the_file:
            self._i_opened_the_file.close()
            self._i_opened_the_file = None
        self._file = None

    def tell(self):
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def getparams(self):
        """Return (nchannels, sampwidth, framerate, nframes, comptype, compname)."""
        return self.getnchannels(), self.getsampwidth(), \
               self.getframerate(), self.getnframes(), \
               self.getcomptype(), self.getcompname()

    def getmarkers(self):
        return None

    def getmark(self, id):
        raise Error('no marks')

    def setpos(self, pos):
        """Set the frame position (0..nframes inclusive); seek happens lazily."""
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = 1

    def readframes(self, nframes):
        """Read at most *nframes* audio frames and return them as a string.

        Sample bytes are returned in the file's little-endian order, except
        that on big-endian hosts 2- and 4-byte samples are swapped to the
        native byte order.
        """
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return ''
        data = self._data_chunk.read(nframes * self._framesize)
        # Only widths with a native array typecode can be swapped; other
        # widths (e.g. 3-byte samples) are returned as stored instead of
        # failing on array.array(None).
        if self._sampwidth in (2, 4) and big_endian:
            data = self._swap_samples(data)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
        return data

    #
    # Internal methods.
    #

    def _swap_samples(self, data):
        """Return *data* with every sample byte-swapped.

        Swaps an in-memory string with array.fromstring() instead of using
        array.fromfile() on the underlying file, so file-like objects work
        and no Chunk internals need to be touched.  A trailing partial
        sample is dropped, matching the whole-sample read this replaces.
        """
        import array
        samples = array.array(_array_fmts[self._sampwidth])
        whole = len(data) - len(data) % samples.itemsize
        samples.fromstring(data[:whole])
        samples.byteswap()
        return samples.tostring()

    def _read_fmt_chunk(self, chunk):
        """Parse a 'fmt ' chunk and set the stream parameters from it.

        Only WAVE_FORMAT_PCM is supported.  Raises Error for any other
        format tag, and for a header declaring zero channels or a zero
        sample width (either would make the frame size zero).
        """
        # Unsigned fields, as in the RIFF spec; e.g. 0xFFFE
        # (WAVE_FORMAT_EXTENSIBLE) is reported as 65534, not -2.
        wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, \
            wBlockAlign = struct.unpack('<HHLLH', chunk.read(14))
        if wFormatTag != WAVE_FORMAT_PCM:
            raise Error('unknown format: %r' % (wFormatTag,))
        sampwidth = struct.unpack('<H', chunk.read(2))[0]
        # Bits per sample, rounded up to whole bytes.
        self._sampwidth = (sampwidth + 7) // 8
        if not self._sampwidth:
            raise Error('bad sample width')
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'
273
class Wave_write:
    """Writer for WAVE files holding uncompressed (PCM) audio.

    These variables are user settable through appropriate methods
    of this class:
    _file -- the open file with methods write(), tell(), seek(), flush()
              (and close(), if __init__ opened it); set through initfp()
    _comptype -- the compression type (only 'NONE' is accepted)
              set through the setcomptype() or setparams() method
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- True once the header has been written
    _form_length_pos, _data_length_pos -- file offsets of the RIFF and
              data chunk size fields, rewritten by _patchheader()
    _convert -- optional callable applied to frames before writing;
              never set by this class (always None)
    _i_opened_the_file -- the file object opened by __init__ when it was
              given a filename, else None; closed by close()
    """

    def __init__(self, f):
        """Open *f* (a filename or an open file object) for writing."""
        # Set first so close() (also run from __del__) works even if
        # opening the file below fails.
        self._file = None
        self._i_opened_the_file = None
        if isinstance(f, basestring):
            f = __builtin__.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except:
            # Don't leak a file we opened ourselves; re-raise unchanged.
            if self._i_opened_the_file:
                f.close()
                self._i_opened_the_file = None
            raise

    def initfp(self, file):
        """Reset all parameters and counters for writing to *file*."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # Only uncompressed PCM is ever written, so default to it; this
        # also lets getparams()/getcomptype() work without setcomptype().
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if framerate <= 0:
            raise Error('bad frame rate')
        self._framerate = framerate

    def getframerate(self):
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def setparams(self, params):
        """Set (nchannels, sampwidth, framerate, nframes, comptype, compname)."""
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return (nchannels, sampwidth, framerate, nframes, comptype, compname).

        Raises Error unless channels, sample width and frame rate are set.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return self._nchannels, self._sampwidth, self._framerate, \
              self._nframes, self._comptype, self._compname

    def setmark(self, id, pos, name):
        raise Error('setmark() not supported')

    def getmark(self, id):
        raise Error('no marks')

    def getmarkers(self):
        return None

    def tell(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching up the header sizes."""
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        # WAVE samples are little-endian: swap native samples on big-endian
        # hosts, but only for widths that have an array typecode (3-byte
        # samples would otherwise fail on array.array(None)).  Writing
        # tostring() instead of using array.tofile() keeps file-like
        # objects working.
        if self._sampwidth in (2, 4) and big_endian:
            import array
            samples = array.array(_array_fmts[self._sampwidth], data)
            samples.byteswap()
            data = samples.tostring()
        self._file.write(data)
        self._datawritten = self._datawritten + len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write audio frames and patch the header sizes if they changed."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Finalize the header, flush, and close a file opened by __init__.

        The self-opened file is closed even if finalizing the header raises
        (e.g. because the parameters were never set); the instance is
        unusable afterwards either way.
        """
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            opened = self._i_opened_the_file
            if opened:
                self._i_opened_the_file = None
                opened.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once, after checking the required parameters."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, 'fmt ' and 'data' headers.

        If nframes was never set, it is derived from *initlength*, the
        byte length of the first block of frames.
        """
        assert not self._headerwritten
        self._file.write('RIFF')
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        self._form_length_pos = self._file.tell()
        # Unsigned 32-bit fields ('L'/'H') as in the RIFF spec, so sizes of
        # 2 GB and more can still be represented.
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, 'WAVE', 'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, 'data'))
        self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite the RIFF and data sizes to match the bytes written."""
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
490
def open(f, mode=None):
    """Open a WAVE file and return a Wave_read or Wave_write instance.

    f    -- a filename or an already open file object
    mode -- 'r' or 'rb' to read, 'w' or 'wb' to write; when omitted,
            the mode attribute of *f* is used if it has one, else 'rb'
    Raises Error for any other mode.
    """
    if mode is None:
        mode = getattr(f, 'mode', 'rb')
    if mode in ('w', 'wb'):
        return Wave_write(f)
    if mode in ('r', 'rb'):
        return Wave_read(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
503
openfp = open # B/W compatibility: older name for open()
505