1#!/usr/bin/python
2
3# Audio Tools, a module and set of tools for manipulating audio data
4# Copyright (C) 2007-2014  Brian Langenberger
5
6# This program is free software; you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation; either version 2 of the License, or
9# (at your option) any later version.
10
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15
16# You should have received a copy of the GNU General Public License
17# along with this program; if not, write to the Free Software
18# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
19
20from audiotools.bitstream import BitstreamWriter
21from audiotools.bitstream import BitstreamRecorder
22from audiotools.bitstream import format_size
23from audiotools import BufferedPCMReader
24from hashlib import md5
25
# sub block IDs
#
# these are the "function" values handed to write_sub_block();
# 0x2 is used twice: the Wave footer is written with nondecoder_data=1
# while the correlation terms are written with nondecoder_data=0

# IDs this module writes with nondecoder_data=0
WV_DUMMY = 0x0          # placeholder, later overwritten by a Wave header
WV_TERMS = 0x2
WV_WEIGHTS = 0x3
WV_SAMPLES = 0x4
WV_ENTROPY = 0x5
WV_INT32_INFO = 0x9
WV_BITSTREAM = 0xA
WV_CHANNEL_INFO = 0xD

# IDs this module writes with nondecoder_data=1
WV_WAVE_HEADER = 0x1
WV_WAVE_FOOTER = 0x2
WV_MD5 = 0x6
WV_SAMPLE_RATE = 0x7
39
40
class Counter(object):
    """an integer tally that grows by one on every add() call

    used as a BitstreamWriter callback so that int(counter) reports
    how many bytes were written while the callback was installed;
    the byte value handed to add() is ignored
    """

    def __init__(self, initial_value=0):
        self.__value__ = int(initial_value)

    def __int__(self):
        return self.__value__

    def add(self, byte):
        # only the number of calls matters, never the byte itself
        self.__value__ = self.__value__ + 1
50
51
class EncoderContext(object):
    """state shared by every block written during one encode_wavpack call

    pcmreader is the PCM source being encoded
    block_parameters is the list of EncodingParameters, one per
    1-2 channel block written for each group of PCM frames
    wave_header / wave_footer are optional RIFF bytes supplied by the caller
    """

    def __init__(self, pcmreader, block_parameters,
                 wave_header=None, wave_footer=None):
        # caller-provided inputs
        self.pcmreader = pcmreader
        self.block_parameters = block_parameters
        self.wave_header = wave_header
        self.wave_footer = wave_footer

        # running totals, accumulated as PCM frames are encoded
        self.total_frames = 0
        self.md5sum = md5()

        # Wave header bookkeeping: whether it has been stored yet, and the
        # writer position of a generated placeholder header (if any)
        self.first_block_written = False
        self.wave_header_start = None
63
64
def write_wave_header(writer, pcmreader, total_frames, wave_footer_len):
    """builds a RIFF WAVE header with writer.build()

    the header covers the 'RIFF' chunk header, the complete 'fmt ' chunk
    and the 'data' chunk's ID and size; the PCM payload is not written

    writer is a BitstreamWriter-compatible object
    pcmreader supplies sample_rate, channels, bits_per_sample
    and (for the extended header) channel_mask
    total_frames is the PCM frame count used to size the 'data' chunk
    wave_footer_len is the number of bytes following the 'data' chunk,
    which are counted in the RIFF chunk's size

    up to 2 channels of up to 16 bits use a classic PCM 'fmt ' chunk,
    anything else uses the extended (0xFFFE) layout
    """

    bytes_per_sample = pcmreader.bits_per_sample // 8
    block_align = pcmreader.channels * bytes_per_sample
    avg_bytes_per_second = pcmreader.sample_rate * block_align

    common_fmt = "16u 16u 32u 32u 16u 16u"
    common_fields = (pcmreader.channels,
                     pcmreader.sample_rate,
                     avg_bytes_per_second,
                     block_align,
                     pcmreader.bits_per_sample)

    if (pcmreader.channels > 2) or (pcmreader.bits_per_sample > 16):
        # extended fmt chunk:
        # CB size, valid bits per sample, channel mask, 16 byte sub format
        fmt = common_fmt + "16u 16u 32u 16b"
        fmt_fields = ((0xFFFE,) +
                      common_fields +
                      (22,
                       pcmreader.bits_per_sample,
                       int(pcmreader.channel_mask),
                       b'\x01\x00\x00\x00\x00\x00\x10\x00' +
                       b'\x80\x00\x00\xaa\x00\x38\x9b\x71'))
    else:
        # classic fmt chunk, compression code 1
        fmt = common_fmt
        fmt_fields = (1,) + common_fields

    fmt_chunk_size = format_size(fmt) // 8
    data_size = total_frames * block_align

    # the RIFF size excludes the 'RIFF' ID and the size field itself
    riff_size = (4 +                    # 'WAVE'
                 8 + fmt_chunk_size +   # 'fmt ' ID, size and body
                 8 + data_size +        # 'data' ID, size and payload
                 wave_footer_len)

    writer.build("4b 32u 4b 4b 32u" + fmt + "4b 32u",
                 (b'RIFF', riff_size, b'WAVE', b'fmt ', fmt_chunk_size) +
                 fmt_fields +
                 (b'data', data_size))
114
115
class CorrelationParameters(object):
    """the parameters for a single correlation pass"""

    def __init__(self, term, delta, weights, samples):
        """term is a signed integer
        delta is an unsigned integer
        weights[c] is a weight value per channel c
        samples[c][s] is sample "s" for channel "c"
        """

        # FIXME - sanity check these

        self.term = term
        self.delta = delta
        self.weights = weights
        self.samples = samples

    def __repr__(self):
        return "CorrelationParameters(%s, %s, %s, %s)" % \
            (self.term, self.delta, self.weights, self.samples)

    def update_weights(self, weights):
        """given a weights[c] list of weight values per channel c
        round-trips and sets this parameter's weights

        each weight is passed through store_weight() then restore_weight()
        (defined elsewhere in this module), presumably so the kept value
        matches what can be stored in the file - TODO confirm

        the channel count must match the current weights (asserted)
        """

        assert(len(weights) == len(self.weights))
        self.weights = [restore_weight(store_weight(w))
                        for w in weights]

    def update_samples(self, samples):
        """given a samples[c][s] list of sample lists
        round-trips and sets this parameter's samples

        each sample is passed through wv_log2() then wv_exp2()
        (defined elsewhere in this module)

        the channel count must match the current samples (asserted)
        """

        # compare against our own channel count, as update_weights does;
        # the old check compared "samples" with itself and never failed
        assert(len(samples) == len(self.samples))
        self.samples = [[wv_exp2(wv_log2(s)) for s in c]
                        for c in samples]
152
153
class EncodingParameters(object):
    """the encoding parameters for a single 1-2 channel block
    multi-channel audio may have more than one set of these

    the CorrelationParameters yielded by correlation_parameters() are
    cached and reused across calls, so any state updated on them
    (e.g. via update_weights / update_samples) while one block is encoded
    is still present for the next block using this object; they are
    rebuilt from zero only when the number of correlated channels changes
    """

    # correlation terms, in pass order, for each supported pass count
    # when 1 channel is correlated (mono or false stereo blocks);
    # the 10 and 16 pass configurations reuse the 5 pass terms
    _MONO_TERMS = {0: (),
                   1: (18,),
                   2: (17, 18),
                   5: (3, 17, 2, 18, 18),
                   10: (3, 17, 2, 18, 18),
                   16: (3, 17, 2, 18, 18)}

    # correlation terms, in pass order, for each supported pass count
    # when 2 channels are correlated
    _STEREO_TERMS = {0: (),
                     1: (18,),
                     2: (17, 18),
                     5: (3, 17, 2, 18, 18),
                     10: (4, 17, -1, 5, 3, 2, -2, 18, 18, 18),
                     16: (2, 18, -1, 8, 6, 3, 5, 7, 4, 2, 18, -2, 3, 2,
                          18, 18)}

    # initial delta of every correlation pass
    _INITIAL_DELTA = 2

    def __init__(self, channel_count, correlation_passes):
        """channel_count is 1 or 2
        correlation_passes is in [0,1,2,5,10,16]
        """

        assert((channel_count == 1) or (channel_count == 2))
        assert(correlation_passes in (0, 1, 2, 5, 10, 16))

        self.channel_count = channel_count
        self.correlation_passes = correlation_passes
        self.entropy_variables = [[0, 0, 0], [0, 0, 0]]

        # correlated channel count the cached parameters were built for
        # (0 means nothing has been built yet)
        self.__parameters_channel_count__ = 0
        self.__correlation_parameters__ = None

    def __repr__(self):
        return "EncodingParameters(%s, %s, %s)" % \
            (self.channel_count,
             self.correlation_passes,
             self.entropy_variables)

    @staticmethod
    def _term_sample_count(term):
        """returns the length of each channel's initial sample list
        for a correlation pass using "term\""""

        if 1 <= term <= 8:
            return term
        elif term in (17, 18):
            return 2
        else:
            # negative (cross channel) terms
            return 1

    def _initial_parameters(self, channel_count):
        """returns a list of zeroed CorrelationParameters, one per pass,
        for "channel_count" (1 or 2) correlated channels

        raises ValueError if the pass count has no term table"""

        if channel_count == 1:
            terms = self._MONO_TERMS.get(self.correlation_passes)
        else:
            terms = self._STEREO_TERMS.get(self.correlation_passes)
        if terms is None:
            raise ValueError("invalid correlation pass count")

        return [CorrelationParameters(
                    term,
                    self._INITIAL_DELTA,
                    [0] * channel_count,
                    [[0] * self._term_sample_count(term)
                     for c in range(channel_count)])
                for term in terms]

    def correlation_parameters(self, false_stereo):
        """given a "false_stereo" boolean
        yields a CorrelationParameters object per correlation pass to be run

        this may be less than the object's "correlation_passes" count
        if "channel_count" is 1 or "false_stereo" is True
        """

        if (self.channel_count == 2) and (not false_stereo):
            channel_count = 2
        else:
            channel_count = 1

        if channel_count != self.__parameters_channel_count__:
            self.__correlation_parameters__ = \
                self._initial_parameters(channel_count)
            # record which layout is cached so later calls reuse these
            # objects instead of rebuilding zeroed ones every time
            self.__parameters_channel_count__ = channel_count

        for parameters in self.__correlation_parameters__:
            yield parameters
299
300
def block_parameters(channel_count, channel_mask, correlation_passes):
    """returns a list of EncodingParameters, one per 1-2 channel block,
    describing how a frame of "channel_count" channels is split up

    1 or 2 channels always form a single block; a few known channel masks
    group front left/right (and back left/right) into stereo blocks;
    any other layout is encoded as one mono block per channel
    """

    # (channel count, channel mask, channels per block), checked in order
    known_layouts = (
        # front left, front right, front center
        (3, 0x7, (2, 1)),
        # front left, front right, back left, back right
        (4, 0x33, (2, 2)),
        # front left, front right, front center, back center
        (4, 0x107, (2, 1, 1)),
        # front left, front right, front center, back left, back right
        (5, 0x37, (2, 1, 2)),
        # front left, front right, front center, LFE, back left, back right
        (6, 0x3F, (2, 1, 1, 2)))

    if channel_count == 1:
        layout = (1,)
    elif channel_count == 2:
        layout = (2,)
    else:
        layout = None
        for (count, mask, block_channels) in known_layouts:
            if (channel_count == count) and (channel_mask == mask):
                layout = block_channels
                break
        if layout is None:
            layout = [1] * channel_count

    return [EncodingParameters(block_channels, correlation_passes)
            for block_channels in layout]
333
334
def _wave_footer_length(context):
    """returns the length of the context's Wave footer in bytes, 0 if none"""

    if context.wave_footer is None:
        return 0
    else:
        return len(context.wave_footer)


def _write_audio_blocks(writer, context, block_size, total_pcm_frames,
                        total_pcm_frames_positions):
    """writes one block per EncodingParameters for every FrameList of up to
    "block_size" frames read from context.pcmreader

    updates context.total_frames and context.md5sum as frames are read;
    when total_pcm_frames is 0, the writer positions of each block's
    total PCM frames field are appended to total_pcm_frames_positions

    returns the total number of PCM frames encoded
    """

    pcmreader = context.pcmreader
    block_index = 0

    # walk through PCM reader's FrameLists
    frame = pcmreader.read(block_size)
    while len(frame) > 0:
        context.total_frames += frame.frames
        context.md5sum.update(
            frame.to_bytes(False, pcmreader.bits_per_sample >= 16))

        c = 0
        for parameters in context.block_parameters:
            if parameters.channel_count == 1:
                channel_data = [list(frame.channel(c))]
            else:
                channel_data = [list(frame.channel(c)),
                                list(frame.channel(c + 1))]

            first_block = parameters is context.block_parameters[0]
            last_block = parameters is context.block_parameters[-1]

            total_pcm_frames_pos = write_block(writer,
                                               context,
                                               channel_data,
                                               total_pcm_frames,
                                               block_index,
                                               first_block,
                                               last_block,
                                               parameters)

            c += parameters.channel_count
            if total_pcm_frames == 0:
                total_pcm_frames_positions.append(total_pcm_frames_pos)

        block_index += frame.frames
        frame = pcmreader.read(block_size)

    return block_index


def _write_final_block(writer, context, total_pcm_frames):
    """writes the trailing 0-sample block holding the MD5 sum and the
    optional Wave footer

    if no audio block was written (empty input), the Wave header has not
    been stored yet, so it is stored in this block instead

    returns the writer position of the block's total PCM frames field
    """

    pcmreader = context.pcmreader
    sub_block = BitstreamRecorder(1)
    sub_blocks_size = Counter()

    # write block header with placeholder sub blocks size
    (sub_blocks_size_pos,
     total_pcm_frames_pos) = write_block_header(
        writer=writer,
        sub_blocks_size=int(sub_blocks_size),
        total_pcm_frames=total_pcm_frames,
        block_index=0xFFFFFFFF,
        block_samples=0,
        bits_per_sample=pcmreader.bits_per_sample,
        channel_count=1,
        joint_stereo=0,
        cross_channel_decorrelation=0,
        wasted_bps=0,
        initial_block_in_sequence=1,
        final_block_in_sequence=1,
        maximum_magnitude=0,
        sample_rate=pcmreader.sample_rate,
        false_stereo=0,
        CRC=0xFFFFFFFF)

    # count every byte written from here on as sub block data
    writer.add_callback(sub_blocks_size.add)

    # empty input: no audio block stored the Wave header, so store it here
    # (a generated header is already final since no frames were encoded)
    if not context.first_block_written:
        sub_block.reset()
        if context.wave_header is None:
            write_wave_header(sub_block, pcmreader, context.total_frames,
                              _wave_footer_length(context))
        else:
            sub_block.write_bytes(context.wave_header)
        write_sub_block(writer, WV_WAVE_HEADER, 1, sub_block)
        context.first_block_written = True

    # write MD5 in final block
    sub_block.reset()
    sub_block.write_bytes(context.md5sum.digest())
    write_sub_block(writer, WV_MD5, 1, sub_block)

    # write Wave footer in final block, if present
    if context.wave_footer is not None:
        sub_block.reset()
        sub_block.write_bytes(context.wave_footer)
        write_sub_block(writer, WV_WAVE_FOOTER, 1, sub_block)

    writer.pop_callback()

    # fill in block header with total sub blocks size
    end_of_block = writer.getpos()
    writer.setpos(sub_blocks_size_pos)
    writer.write(32, int(sub_blocks_size) + 24)
    writer.setpos(end_of_block)

    return total_pcm_frames_pos


def _update_generated_wave_header(writer, context):
    """overwrites the placeholder sub block written by the first audio block
    with a generated Wave header sized for context.total_frames

    does nothing if the caller supplied its own header or if no
    placeholder was written
    """

    if (context.wave_header is not None) or \
            (context.wave_header_start is None):
        return

    sub_block = BitstreamRecorder(1)
    write_wave_header(sub_block, context.pcmreader,
                      context.total_frames, _wave_footer_length(context))
    writer.setpos(context.wave_header_start)
    write_sub_block(writer, WV_WAVE_HEADER, 1, sub_block)


def encode_wavpack(filename,
                   pcmreader,
                   block_size,
                   total_pcm_frames=0,
                   false_stereo=False,
                   wasted_bits=False,
                   joint_stereo=False,
                   correlation_passes=0,
                   wave_header=None,
                   wave_footer=None):
    """encodes "pcmreader" to a new WavPack file at "filename"

    block_size is the number of PCM frames read per block
    total_pcm_frames, if nonzero, is written into every block header and
    must equal the number of frames actually read (asserted);
    if 0, the block headers are patched with the real total afterward
    correlation_passes is one of 0, 1, 2, 5, 10, 16
    wave_header / wave_footer are optional RIFF bytes stored verbatim;
    without a wave_header, one is generated from the stream's format

    false_stereo, wasted_bits and joint_stereo are accepted for
    compatibility but not consulted; write_block decides those per block

    the output file is closed even if encoding fails
    """

    pcmreader = BufferedPCMReader(pcmreader)
    writer = BitstreamWriter(open(filename, "wb"), True)
    try:
        context = EncoderContext(pcmreader,
                                 block_parameters(pcmreader.channels,
                                                  pcmreader.channel_mask,
                                                  correlation_passes),
                                 wave_header,
                                 wave_footer)
        total_pcm_frames_positions = []

        block_index = _write_audio_blocks(writer,
                                          context,
                                          block_size,
                                          total_pcm_frames,
                                          total_pcm_frames_positions)

        # write MD5 sum and optional Wave footer in final block
        final_pos = _write_final_block(writer, context, total_pcm_frames)
        if total_pcm_frames == 0:
            total_pcm_frames_positions.append(final_pos)

        # update Wave header's "data" chunk size, if generated
        _update_generated_wave_header(writer, context)

        # go back and populate block headers with total samples
        if total_pcm_frames > 0:
            assert(block_index == total_pcm_frames)

        for pos in total_pcm_frames_positions:
            writer.setpos(pos)
            writer.write(32, block_index)

        writer.flush()
    finally:
        writer.close()
462
463
def write_block(writer,
                context,
                channels,
                total_pcm_frames,
                block_index,
                first_block,
                last_block,
                parameters):
    """writes a single 1 or 2 channel WavPack block to "writer"

    writer is a BitstreamWriter-compatible object
    context is an EncoderContext object
    channels[c][s] is sample "s" in channel "c"
    total_pcm_frames is stored in the header's total PCM frames field
    (encode_wavpack passes 0 and patches it later when unknown)
    block_index is an integer of the block's offset in PCM frames
    first_block and last_block are flags indicating the block's sequence
    parameters is an EncodingParameters object

    returns the writer position of the header's total PCM frames field
    """

    assert((len(channels) == 1) or (len(channels) == 2))

    if (len(channels) == 1) or (channels[0] == channels[1]):
        # 1 channel block, or 2 identical channels stored once
        # and flagged as "false stereo"
        false_stereo = 0 if (len(channels) == 1) else 1

        # calculate maximum magnitude of channel_0
        magnitude = max(map(bits, channels[0]))

        # determine wasted bits
        wasted = min(map(wasted_bps, channels[0]))
        if wasted == INFINITY:
            # all samples are 0
            wasted = 0

        # if wasted bits, remove them from channel_0
        # (wasted can no longer be INFINITY at this point)
        if wasted > 0:
            shifted = [[s >> wasted for s in channels[0]]]
        else:
            shifted = [channels[0]]

        # calculate CRC of shifted_0
        crc = calculate_crc(shifted)
    else:
        # 2 channel block
        false_stereo = 0

        # calculate maximum magnitude of channel_0/channel_1
        magnitude = max(max(map(bits, channels[0])),
                        max(map(bits, channels[1])))

        # determine wasted bits
        wasted = min(min(map(wasted_bps, channels[0])),
                     min(map(wasted_bps, channels[1])))
        if wasted == INFINITY:
            # all samples are 0
            wasted = 0

        # if wasted bits, remove them from channel_0/channel_1
        if wasted > 0:
            shifted = [[s >> wasted for s in channels[0]],
                       [s >> wasted for s in channels[1]]]
        else:
            shifted = channels

        # calculate CRC of shifted_0/shifted_1
        crc = calculate_crc(shifted)

        # joint stereo conversion of shifted_0/shifted_1 to mid/side channels
        mid_side = joint_stereo(shifted[0], shifted[1])

    # take the correlation passes once, so the header flag, the
    # terms/weights/samples sub blocks and the correlation step below
    # all use the very same CorrelationParameters objects
    passes = list(parameters.correlation_parameters(false_stereo))

    # number of channels actually correlated in this block
    correlated_channels = \
        2 if ((len(channels) == 2) and (not false_stereo)) else 1

    sub_block = BitstreamRecorder(1)
    sub_blocks_size = Counter()

    # write block header with placeholder total sub blocks size
    (sub_blocks_size_pos,
     total_pcm_frames_pos) = write_block_header(
        writer=writer,
        sub_blocks_size=int(sub_blocks_size),
        total_pcm_frames=total_pcm_frames,
        block_index=block_index,
        block_samples=len(channels[0]),
        bits_per_sample=context.pcmreader.bits_per_sample,
        channel_count=len(channels),
        joint_stereo=(len(channels) == 2) and (false_stereo == 0),
        cross_channel_decorrelation=any(p.term in (-1, -2, -3)
                                        for p in passes),
        wasted_bps=wasted,
        initial_block_in_sequence=first_block,
        final_block_in_sequence=last_block,
        maximum_magnitude=magnitude,
        sample_rate=context.pcmreader.sample_rate,
        false_stereo=false_stereo,
        CRC=crc)

    # count every byte written from here on as sub block data
    writer.add_callback(sub_blocks_size.add)

    # if first block in file, write Wave header
    if not context.first_block_written:
        sub_block.reset()
        if context.wave_header is None:
            # generated header with a 0 frame "data" chunk; its position
            # is recorded so encode_wavpack can rewrite it once the
            # final frame count is known
            if context.wave_footer is None:
                write_wave_header(sub_block, context.pcmreader, 0, 0)
            else:
                write_wave_header(sub_block, context.pcmreader, 0,
                                  len(context.wave_footer))
            context.wave_header_start = writer.getpos()
            write_sub_block(writer, WV_DUMMY, 0, sub_block)
        else:
            sub_block.write_bytes(context.wave_header)
            write_sub_block(writer, WV_WAVE_HEADER, 1, sub_block)
        context.first_block_written = True

    # if correlation passes, write three sub blocks of pass data
    if parameters.correlation_passes > 0:
        sub_block.reset()
        write_correlation_terms(
            sub_block,
            [p.term for p in passes],
            [p.delta for p in passes])
        write_sub_block(writer, WV_TERMS, 0, sub_block)

        sub_block.reset()
        write_correlation_weights(
            sub_block,
            [p.weights for p in passes])
        write_sub_block(writer, WV_WEIGHTS, 0, sub_block)

        sub_block.reset()
        write_correlation_samples(
            sub_block,
            [p.term for p in passes],
            [p.samples for p in passes],
            correlated_channels)
        write_sub_block(writer, WV_SAMPLES, 0, sub_block)

    # if wasted bits, write extended integers sub block
    if wasted > 0:
        sub_block.reset()
        write_extended_integers(sub_block, 0, wasted, 0, 0)
        write_sub_block(writer, WV_INT32_INFO, 0, sub_block)

    # if channel count > 2, write channel info sub block
    if context.pcmreader.channels > 2:
        sub_block.reset()
        sub_block.write(8, context.pcmreader.channels)
        # int() as in write_wave_header, in case the mask is an object
        sub_block.write(32, int(context.pcmreader.channel_mask))
        write_sub_block(writer, WV_CHANNEL_INFO, 0, sub_block)

    # if nonstandard sample rate, write sample rate sub block
    if (context.pcmreader.sample_rate not in
        (6000, 8000, 9600, 11025, 12000, 16000, 22050, 24000,
         32000, 44100, 48000, 64000, 88200, 96000, 192000)):
        sub_block.reset()
        sub_block.write(32, context.pcmreader.sample_rate)
        write_sub_block(writer, WV_SAMPLE_RATE, 1, sub_block)

    if (len(channels) == 1) or (false_stereo):
        # 1 channel block

        # correlate shifted_0 with terms/deltas/weights/samples
        if parameters.correlation_passes > 0:
            assert(len(shifted) == 1)
            correlated = correlate_channels(shifted, passes, 1)
        else:
            correlated = shifted
    else:
        # 2 channel block

        # correlate mid/side with terms/deltas/weights/samples
        if parameters.correlation_passes > 0:
            assert(len(mid_side) == 2)
            correlated = correlate_channels(mid_side, passes, 2)
        else:
            correlated = mid_side

    # write entropy variables sub block
    sub_block.reset()
    write_entropy_variables(sub_block, correlated,
                            parameters.entropy_variables)
    write_sub_block(writer, WV_ENTROPY, 0, sub_block)

    # write bitstream sub block
    sub_block.reset()
    write_bitstream(sub_block, correlated,
                    parameters.entropy_variables)
    write_sub_block(writer, WV_BITSTREAM, 0, sub_block)

    writer.pop_callback()

    # fill in total sub blocks size
    # (+ 24 for the header bytes following the block size field)
    end_of_block = writer.getpos()
    writer.setpos(sub_blocks_size_pos)
    writer.write(32, int(sub_blocks_size) + 24)
    writer.setpos(end_of_block)

    # round-trip entropy variables
    parameters.entropy_variables = [
        [wv_exp2(wv_log2(p)) for p in parameters.entropy_variables[0]],
        [wv_exp2(wv_log2(p)) for p in parameters.entropy_variables[1]]]

    # return total PCM frames position to be filled in later
    return total_pcm_frames_pos
679
680
def bits(sample):
    """returns how many bits the magnitude of integer "sample" occupies
    (0 for a 0 sample)"""

    return abs(sample).bit_length()
688
# sentinel returned by wasted_bps() for a 0 sample,
# which has no lowest set bit
INFINITY = 1 << 32


def wasted_bps(sample):
    """returns the number of trailing zero bits in integer "sample",
    or INFINITY if the sample is 0"""

    if sample == 0:
        return INFINITY

    # x & -x isolates the lowest set bit (also for negative integers),
    # whose position equals the trailing zero count
    lowest_set_bit = sample & -sample
    return lowest_set_bit.bit_length() - 1
701
702
def calculate_crc(samples):
    """returns the 32-bit block CRC of "samples"

    samples[c][s] is sample "s" of channel "c"; samples are folded in
    interleaved order (frame by frame, then channel by channel) as
    crc = 3 * crc + sample, starting from 0xFFFFFFFF, modulo 2 ** 32
    """

    crc = 0xFFFFFFFF

    for frame in zip(*samples):
        for s in frame:
            # reducing at every step keeps crc a 32-bit value instead of
            # an ever-growing integer (which made this quadratic); since
            # only the residue modulo 2 ** 32 is returned, the result is
            # unchanged, and & also maps negative values into range
            crc = (3 * crc + s) & 0xFFFFFFFF

    return crc
714
715
def joint_stereo(left, right):
    """converts equal-length "left" and "right" sample lists into
    [left - right, (left + right) // 2] lists (called mid and side here)"""

    assert(len(left) == len(right))

    pairs = list(zip(left, right))
    return [[l - r for (l, r) in pairs],
            [(l + r) // 2 for (l, r) in pairs]]
725
726
def write_block_header(writer,
                       sub_blocks_size,
                       total_pcm_frames,
                       block_index,
                       block_samples,
                       bits_per_sample,
                       channel_count,
                       joint_stereo,
                       cross_channel_decorrelation,
                       wasted_bps,
                       initial_block_in_sequence,
                       final_block_in_sequence,
                       maximum_magnitude,
                       sample_rate,
                       false_stereo,
                       CRC):
    """writes a WavPack block header to "writer"

    sub_blocks_size is the total size of the block's sub blocks in bytes;
    a value of 0 writes 0 into the block size field as a placeholder
    the remaining arguments populate the header's fields and flag bits

    returns a (sub_blocks_size_pos, total_pcm_frames_pos) tuple holding
    the writer positions of the block size and total PCM frames fields,
    so callers can setpos() back and fill them in later
    """

    # 4-bit sample rate codes; anything else is written as 15
    sample_rate_codes = {6000: 0,
                         8000: 1,
                         9600: 2,
                         11025: 3,
                         12000: 4,
                         16000: 5,
                         22050: 6,
                         24000: 7,
                         32000: 8,
                         44100: 9,
                         48000: 10,
                         64000: 11,
                         88200: 12,
                         96000: 13,
                         192000: 14}

    writer.write_bytes(b"wvpk")                 # block ID

    # block size: the 24 header bytes after this field plus the sub blocks
    sub_blocks_size_pos = writer.getpos()
    writer.write(32, (sub_blocks_size + 24) if (sub_blocks_size > 0) else 0)

    writer.write(16, 0x0410)                    # version
    writer.write(8, 0)                          # track number
    writer.write(8, 0)                          # index number

    total_pcm_frames_pos = writer.getpos()
    for field in (total_pcm_frames, block_index, block_samples):
        writer.write(32, field)

    # flag fields, as (bit count, value) in the order they are written
    flag_fields = (
        (2, (bits_per_sample // 8) - 1),        # bytes per sample - 1
        (1, 2 - channel_count),                 # mono
        (1, 0),                                 # hybrid mode
        (1, joint_stereo),
        (1, cross_channel_decorrelation),
        (1, 0),                                 # hybrid noise shaping
        (1, 0),                                 # floating point data
        (1, 1 if (wasted_bps > 0) else 0),      # extended size integers
        (1, 0),                                 # hybrid controls bitrate
        (1, 0),                                 # hybrid noise balanced
        (1, initial_block_in_sequence),
        (1, final_block_in_sequence),
        (5, 0),                                 # left shift data
        (5, maximum_magnitude),
        (4, sample_rate_codes.get(sample_rate, 15)),
        (2, 0),                                 # reserved
        (1, 0),                                 # use IIR
        (1, false_stereo),
        (1, 0))                                 # reserved
    for (width, value) in flag_fields:
        writer.write(width, value)

    writer.write(32, CRC)

    return (sub_blocks_size_pos, total_pcm_frames_pos)
796
797
def write_sub_block(writer, function, nondecoder_data, recorder):
    """Write one WavPack sub-block whose payload is held in recorder.

    Layout written to writer:
      5 bits function ID, 1 bit nondecoder_data, 1 bit "odd size" flag,
      1 bit "large" flag, then the payload size in 16-bit words
      (8 bits, or 24 bits when the payload exceeds 510 bytes),
      then the payload itself, then one zero pad byte if the
      payload has an odd number of bytes.

    recorder is byte-aligned before its length is measured.
    """
    recorder.byte_align()
    payload_length = recorder.bytes()
    odd_length = payload_length % 2

    writer.build("5u 1u 1u",
                 (function,
                  nondecoder_data,
                  odd_length))

    # the size field counts 16-bit words, rounding an odd byte up
    word_count = (payload_length // 2) + odd_length
    if payload_length > (255 * 2):
        (large_flag, size_bits) = (1, 24)
    else:
        (large_flag, size_bits) = (0, 8)
    writer.write(1, large_flag)
    writer.write(size_bits, word_count)

    recorder.copy(writer)

    if odd_length:
        # pad the payload out to a whole number of words
        writer.write(8, 0)
819
820
def write_correlation_terms(writer, correlation_terms, correlation_deltas):
    """Write each correlation pass's term and delta to sub block data.

    correlation_terms[p] and correlation_deltas[p] are ints for each
    correlation pass, in descending order.  Each pass is written as a
    5-bit term biased by +5 (e.g. -3 -> 2, 18 -> 23) followed by a
    3-bit delta.
    """
    assert(len(correlation_terms) == len(correlation_deltas))

    biased_pairs = ((term + 5, delta)
                    for (term, delta) in zip(correlation_terms,
                                             correlation_deltas))
    for (stored_term, stored_delta) in biased_pairs:
        writer.write(5, stored_term)
        writer.write(3, stored_delta)
831
832
def write_correlation_weights(writer, correlation_weights):
    """correlation_weights[p][c]
    are lists of correlation weight ints for each pass and channel
    in descending order
    writes the weights to sub block data in the proper order/format

    Each weight is reduced by store_weight() to a value in the range
    -128 to 127, so it must be written as a signed 8-bit field
    (negative weights cannot be represented by an unsigned write)."""

    for weights in correlation_weights:
        for weight in weights:
            writer.write_signed(8, store_weight(weight))
842
843
def store_weight(w):
    """Compress a correlation weight into the 8-bit stored form.

    w is first clamped to [-1024, 1024].  Positive weights are reduced
    by (w + 64) // 128 before scaling (the inverse of the adjustment
    restore_weight() adds back); every weight is then scaled down by 8
    with rounding, giving a result between -128 and 127.
    """
    clamped = min(max(w, -1024), 1024)

    if clamped > 0:
        return (clamped - ((clamped + 2 ** 6) // 2 ** 7) + 4) // (2 ** 3)
    if clamped < 0:
        return (clamped + 4) // (2 ** 3)
    if clamped == 0:
        return 0
    # only reachable for unordered values; mirrors the original fall-through
    return None
853
854
def restore_weight(v):
    """Expand an 8-bit stored correlation weight back to its full value.

    The stored value is multiplied by 8; positive results additionally
    get (result + 64) // 128 added, approximately undoing the reduction
    store_weight() applies to positive weights.
    """
    if v == 0:
        return 0

    expanded = v * 2 ** 3
    if v > 0:
        expanded += (expanded + 2 ** 6) // 2 ** 7
    return expanded
862
863
def write_correlation_samples(writer, correlation_terms, correlation_samples,
                              channel_count):
    """Write each pass's correlation samples as 16-bit signed wv_log2 values.

    correlation_terms[p] are correlation term ints for each pass
    correlation_samples[p][c][s] are lists of correlation sample ints
    for each pass and channel in descending order

    What is written per pass:
      terms 17 and 18   samples [0] and [1] of channel 0, then of channel 1
      terms 1 to 8      the first 'term' samples, channels interleaved
      terms -1 to -3    sample [0] of channel 0 then channel 1 (stereo only)

    raises ValueError for a channel count other than 1 or 2,
    or for a term not supported with that channel count"""

    assert(len(correlation_terms) == len(correlation_samples))

    if channel_count == 2:
        channel_indexes = (0, 1)
    elif channel_count == 1:
        channel_indexes = (0,)
    else:
        raise ValueError("invalid channel count")

    def put(sample):
        writer.write_signed(16, wv_log2(sample))

    for (term, samples) in zip(correlation_terms, correlation_samples):
        if 17 <= term <= 18:
            for c in channel_indexes:
                put(samples[c][0])
                put(samples[c][1])
        elif 1 <= term <= 8:
            for s in range(term):
                for c in channel_indexes:
                    put(samples[c][s])
        elif (channel_count == 2) and (-3 <= term <= -1):
            for c in channel_indexes:
                put(samples[c][0])
        else:
            raise ValueError("invalid correlation term")
903
904
def wv_log2(value):
    """Return the fixed-point signed base-2 logarithm used by WavPack.

    The magnitude of value is first increased by 1/512 of itself.
    The result's upper part is (bit length of that adjusted magnitude)
    * 256, and the low 8 bits are a fraction looked up in WLOG using
    the 8 bits that follow the leading 1 bit.
    Positive inputs return a positive result; zero and negative inputs
    return the negated logarithm of their magnitude (wv_log2(0) == 0).

    value must be an int.
    """
    magnitude = abs(value)
    adjusted = magnitude + (magnitude // 2 ** 9)

    # exact integer bit count; a float log(a) / log(2) can round the
    # wrong way for magnitudes at or near powers of two
    bit_count = adjusted.bit_length()

    # shift so the leading 1 bit sits at bit 8; the following 8 bits
    # index the fractional-log table
    if adjusted < 256:
        mantissa = (adjusted << (9 - bit_count)) % 256
    else:
        mantissa = (adjusted >> (bit_count - 9)) % 256

    result = (bit_count * 2 ** 8) + WLOG[mantissa]
    return result if (value > 0) else -result


# WLOG[i] holds approximately round(256 * log2(1 + i / 256)),
# the fractional part of the logarithm for a given 8-bit mantissa
WLOG = [0x00, 0x01, 0x03, 0x04, 0x06, 0x07, 0x09, 0x0a,
        0x0b, 0x0d, 0x0e, 0x10, 0x11, 0x12, 0x14, 0x15,
        0x16, 0x18, 0x19, 0x1a, 0x1c, 0x1d, 0x1e, 0x20,
        0x21, 0x22, 0x24, 0x25, 0x26, 0x28, 0x29, 0x2a,
        0x2c, 0x2d, 0x2e, 0x2f, 0x31, 0x32, 0x33, 0x34,
        0x36, 0x37, 0x38, 0x39, 0x3b, 0x3c, 0x3d, 0x3e,
        0x3f, 0x41, 0x42, 0x43, 0x44, 0x45, 0x47, 0x48,
        0x49, 0x4a, 0x4b, 0x4d, 0x4e, 0x4f, 0x50, 0x51,
        0x52, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a,
        0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63,
        0x64, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c,
        0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x74, 0x75,
        0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d,
        0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85,
        0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d,
        0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95,
        0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9b, 0x9c,
        0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4,
        0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xa9, 0xaa, 0xab,
        0xac, 0xad, 0xae, 0xaf, 0xb0, 0xb1, 0xb2, 0xb2,
        0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xb9,
        0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, 0xc0, 0xc0,
        0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc6, 0xc7,
        0xc8, 0xc9, 0xca, 0xcb, 0xcb, 0xcc, 0xcd, 0xce,
        0xcf, 0xd0, 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd4,
        0xd5, 0xd6, 0xd7, 0xd8, 0xd8, 0xd9, 0xda, 0xdb,
        0xdc, 0xdc, 0xdd, 0xde, 0xdf, 0xe0, 0xe0, 0xe1,
        0xe2, 0xe3, 0xe4, 0xe4, 0xe5, 0xe6, 0xe7, 0xe7,
        0xe8, 0xe9, 0xea, 0xea, 0xeb, 0xec, 0xed, 0xee,
        0xee, 0xef, 0xf0, 0xf1, 0xf1, 0xf2, 0xf3, 0xf4,
        0xf4, 0xf5, 0xf6, 0xf7, 0xf7, 0xf8, 0xf9, 0xf9,
        0xfa, 0xfb, 0xfc, 0xfc, 0xfd, 0xfe, 0xff, 0xff]
957
958
def wv_exp2(value):
    """Return the signed base-2 exponential used by WavPack.

    This is the approximate inverse of wv_log2: the low 8 bits of the
    magnitude select a mantissa from WEXP and the remaining bits give
    the shift (right for magnitudes up to 9 * 256, left above that).
    Negative inputs return the negated exponential of their magnitude.

    value must be an int in the range -32768 to 32767.

    raises ValueError for values outside that range (the previous
    implementation silently returned None for them)
    """
    if not ((-32768 <= value) and (value <= 32767)):
        raise ValueError("value out of range")

    magnitude = -value if (value < 0) else value
    mantissa = WEXP[magnitude & 0xFF]
    exponent = magnitude >> 8

    if exponent <= 9:
        result = mantissa >> (9 - exponent)
    else:
        result = mantissa << (exponent - 9)

    return -result if (value < 0) else result


# WEXP[i] holds approximately round(256 * 2 ** (i / 256)),
# i.e. a 9-bit mantissa for each 8-bit fractional exponent
WEXP = [0x100, 0x101, 0x101, 0x102, 0x103, 0x103, 0x104, 0x105,
        0x106, 0x106, 0x107, 0x108, 0x108, 0x109, 0x10a, 0x10b,
        0x10b, 0x10c, 0x10d, 0x10e, 0x10e, 0x10f, 0x110, 0x110,
        0x111, 0x112, 0x113, 0x113, 0x114, 0x115, 0x116, 0x116,
        0x117, 0x118, 0x119, 0x119, 0x11a, 0x11b, 0x11c, 0x11d,
        0x11d, 0x11e, 0x11f, 0x120, 0x120, 0x121, 0x122, 0x123,
        0x124, 0x124, 0x125, 0x126, 0x127, 0x128, 0x128, 0x129,
        0x12a, 0x12b, 0x12c, 0x12c, 0x12d, 0x12e, 0x12f, 0x130,
        0x130, 0x131, 0x132, 0x133, 0x134, 0x135, 0x135, 0x136,
        0x137, 0x138, 0x139, 0x13a, 0x13a, 0x13b, 0x13c, 0x13d,
        0x13e, 0x13f, 0x140, 0x141, 0x141, 0x142, 0x143, 0x144,
        0x145, 0x146, 0x147, 0x148, 0x148, 0x149, 0x14a, 0x14b,
        0x14c, 0x14d, 0x14e, 0x14f, 0x150, 0x151, 0x151, 0x152,
        0x153, 0x154, 0x155, 0x156, 0x157, 0x158, 0x159, 0x15a,
        0x15b, 0x15c, 0x15d, 0x15e, 0x15e, 0x15f, 0x160, 0x161,
        0x162, 0x163, 0x164, 0x165, 0x166, 0x167, 0x168, 0x169,
        0x16a, 0x16b, 0x16c, 0x16d, 0x16e, 0x16f, 0x170, 0x171,
        0x172, 0x173, 0x174, 0x175, 0x176, 0x177, 0x178, 0x179,
        0x17a, 0x17b, 0x17c, 0x17d, 0x17e, 0x17f, 0x180, 0x181,
        0x182, 0x183, 0x184, 0x185, 0x187, 0x188, 0x189, 0x18a,
        0x18b, 0x18c, 0x18d, 0x18e, 0x18f, 0x190, 0x191, 0x192,
        0x193, 0x195, 0x196, 0x197, 0x198, 0x199, 0x19a, 0x19b,
        0x19c, 0x19d, 0x19f, 0x1a0, 0x1a1, 0x1a2, 0x1a3, 0x1a4,
        0x1a5, 0x1a6, 0x1a8, 0x1a9, 0x1aa, 0x1ab, 0x1ac, 0x1ad,
        0x1af, 0x1b0, 0x1b1, 0x1b2, 0x1b3, 0x1b4, 0x1b6, 0x1b7,
        0x1b8, 0x1b9, 0x1ba, 0x1bc, 0x1bd, 0x1be, 0x1bf, 0x1c0,
        0x1c2, 0x1c3, 0x1c4, 0x1c5, 0x1c6, 0x1c8, 0x1c9, 0x1ca,
        0x1cb, 0x1cd, 0x1ce, 0x1cf, 0x1d0, 0x1d2, 0x1d3, 0x1d4,
        0x1d6, 0x1d7, 0x1d8, 0x1d9, 0x1db, 0x1dc, 0x1dd, 0x1de,
        0x1e0, 0x1e1, 0x1e2, 0x1e4, 0x1e5, 0x1e6, 0x1e8, 0x1e9,
        0x1ea, 0x1ec, 0x1ed, 0x1ee, 0x1f0, 0x1f1, 0x1f2, 0x1f4,
        0x1f5, 0x1f6, 0x1f8, 0x1f9, 0x1fa, 0x1fc, 0x1fd, 0x1ff]
1002
1003
def correlate_channels(uncorrelated_samples,
                       correlation_parameters,
                       channel_count):
    """Run every decorrelation pass over a block of samples.

    uncorrelated_samples[c][s] is sample 's' for channel 'c'.
    correlation_parameters is a sequence of objects exposing term,
    delta, weights and samples, plus update_weights() and
    update_samples(); each one is updated after its pass runs.

    Mono input (channel_count == 1) uses single-channel passes on
    channel 0; any other channel count uses two-channel passes.
    Returns correlated_samples[c][s] with sample 's' for channel 'c'.
    """
    if channel_count != 1:
        current = uncorrelated_samples
        for params in correlation_parameters:
            (current,
             new_weights,
             new_samples) = correlation_pass_2ch(current,
                                                 params.term,
                                                 params.delta,
                                                 params.weights,
                                                 params.samples)
            params.update_weights(new_weights)
            params.update_samples(new_samples)
        return current

    current = uncorrelated_samples[0]
    for params in correlation_parameters:
        (current,
         new_weight,
         new_samples) = correlation_pass_1ch(current,
                                             params.term,
                                             params.delta,
                                             params.weights[0],
                                             params.samples[0])
        params.update_weights([new_weight])
        params.update_samples([new_samples])
    return [current]
1039
1040
def correlation_pass_1ch(uncorrelated_samples,
                         term, delta, weight, correlation_samples):
    """given a list of uncorrelated_samples[s]
    term, delta and weight ints
    and a list of correlation_samples[s] ints
    returns a (correlated[s], weight, samples[s]) tuple
    containing correlated samples and updated weight/samples

    correlation_samples layout:
      terms 17 and 18: the two most recent input samples, newest first
      terms 1 to 8:    the 'term' most recent input samples, oldest first

    The returned samples use the same layout and are taken from the
    *input* (uncorrelated) history, not from the correlated output,
    so the caller can carry them into the next call as prediction
    history.

    raises ValueError for any other term"""

    if term == 18:
        assert(len(correlation_samples) == 2)
        history = ([correlation_samples[1],
                    correlation_samples[0]] +
                   uncorrelated_samples)
        correlated = []
        for i in range(2, len(history)):
            # extrapolate from the previous two input samples
            temp = (3 * history[i - 1] - history[i - 2]) // 2
            residual = history[i] - apply_weight(weight, temp)
            correlated.append(residual)
            weight += update_weight(temp, residual, delta)
        return (correlated, weight, [history[-1], history[-2]])
    elif term == 17:
        assert(len(correlation_samples) == 2)
        history = ([correlation_samples[1],
                    correlation_samples[0]] +
                   uncorrelated_samples)
        correlated = []
        for i in range(2, len(history)):
            # linear extrapolation from the previous two input samples
            temp = 2 * history[i - 1] - history[i - 2]
            residual = history[i] - apply_weight(weight, temp)
            correlated.append(residual)
            weight += update_weight(temp, residual, delta)
        return (correlated, weight, [history[-1], history[-2]])
    elif (1 <= term) and (term <= 8):
        assert(len(correlation_samples) == term)
        history = correlation_samples[:] + uncorrelated_samples
        correlated = []
        for i in range(term, len(history)):
            # predict from the input sample 'term' positions back
            source = history[i - term]
            residual = history[i] - apply_weight(weight, source)
            correlated.append(residual)
            weight += update_weight(source, residual, delta)
        return (correlated, weight, history[-term:])
    else:
        raise ValueError("unsupported term")
1083
1084
def correlation_pass_2ch(uncorrelated_samples,
                         term, delta, weights, correlation_samples):
    """given a list of uncorrelated_samples[c][s] lists
    term and delta ints
    a list of weight[c] ints
    and a list of correlation_samples[c][s] lists
    returns (correlated[c][s], weights[c], samples[c][s]) tuple
    containing correlated samples and updated weights/samples

    Positive terms (1 to 8, 17, 18) decorrelate each channel
    independently via correlation_pass_1ch.

    Negative terms (-1 to -3) predict each channel from the other one.
    For these, correlation_samples[0] is a one-element list holding the
    previous right-channel (channel 1) input sample, and
    correlation_samples[1] a one-element list holding the previous
    left-channel (channel 0) input sample.  The returned samples use the
    same layout, taken from the last input samples of this call.
    Weights are clamped to [-1024, 1024] after every update.

    raises ValueError for any other term"""

    assert(len(uncorrelated_samples) == 2)
    assert(len(uncorrelated_samples[0]) == len(uncorrelated_samples[1]))
    assert(len(weights) == 2)

    if ((17 <= term) and (term <= 18)) or ((1 <= term) and (term <= 8)):
        (correlated_0,
         weight_0,
         samples_0) = correlation_pass_1ch(uncorrelated_samples[0],
                                           term, delta, weights[0],
                                           correlation_samples[0])
        (correlated_1,
         weight_1,
         samples_1) = correlation_pass_1ch(uncorrelated_samples[1],
                                           term, delta, weights[1],
                                           correlation_samples[1])
        return ([correlated_0, correlated_1],
                [weight_0, weight_1],
                [samples_0, samples_1])

    elif (-3 <= term) and (term <= -1):
        assert(len(correlation_samples[0]) == 1)
        assert(len(correlation_samples[1]) == 1)
        # prepend each channel's history: left gets the previous left
        # sample, right gets the previous right sample
        left = correlation_samples[1] + uncorrelated_samples[0]
        right = correlation_samples[0] + uncorrelated_samples[1]

        # how far back each channel looks into the *other* channel:
        #   -1: left uses previous right, right uses current left
        #   -2: left uses current right,  right uses previous left
        #   -3: both use the other channel's previous sample
        (left_lag, right_lag) = {-1: (1, 0),
                                 -2: (0, 1),
                                 -3: (1, 1)}[term]

        correlated = [[], []]
        weights = list(weights)
        for i in range(1, len(left)):
            source_left = right[i - left_lag]
            source_right = left[i - right_lag]
            correlated[0].append(left[i] -
                                 apply_weight(weights[0], source_left))
            correlated[1].append(right[i] -
                                 apply_weight(weights[1], source_right))
            weights[0] += update_weight(source_left,
                                        correlated[0][-1],
                                        delta)
            weights[1] += update_weight(source_right,
                                        correlated[1][-1],
                                        delta)
            weights[0] = max(min(weights[0], 1024), -1024)
            weights[1] = max(min(weights[1], 1024), -1024)

        # carry the most recent raw inputs forward in the same layout
        # as correlation_samples (right first, then left)
        return (correlated, weights, [[right[-1]], [left[-1]]])
    else:
        raise ValueError("unsupported term")
1173
1174
def apply_weight(weight, sample):
    """Scale sample by weight / 1024, rounding via +512 before the shift."""
    product = weight * sample
    return (product + 2 ** 9) >> 10
1177
1178
def update_weight(source, result, delta):
    """Return the adjustment to apply to a correlation weight.

    0 if either source or result is zero; +delta when they share a
    sign; -delta when their signs differ.
    """
    if source == 0 or result == 0:
        return 0
    same_sign = (source ^ result) >= 0
    return delta if same_sign else -delta
1186
1187
def write_entropy_variables(writer, channels, entropies):
    """Write each channel's entropy values as 16-bit wv_log2 fields.

    When channels has exactly two entries, entropies[0] then
    entropies[1] are written; otherwise only entropies[0] is written.
    """
    channel_total = 2 if len(channels) == 2 else 1
    for channel_index in range(channel_total):
        for entropy in entropies[channel_index]:
            writer.write(16, wv_log2(entropy))
1197
1198
class Residual(object):
    """A single residual prepared for WavPack's entropy coder.

    Instances are produced by Residual.encode() and written later by
    flush(), which needs the *following* residual's m value to decide
    how this residual's unary prefix is written.

    Attributes (any may be None, e.g. for a placeholder residual or
    one that only represents a run of zeroes):
      zeroes -- length of a zero run written via write_egc() before
                this residual, or None when no zero-run field is written
      m      -- bucket index chosen by encode()
      offset -- position of the residual within its bucket
      add    -- largest offset that bucket can hold
      sign   -- 1 if the residual was negative, otherwise 0
    """

    def __init__(self, zeroes, m, offset, add, sign):
        self.zeroes = zeroes
        self.m = m
        self.offset = offset
        self.add = add
        self.sign = sign

    def __repr__(self):
        return "Residual(%s, %s, %s, %s, %s)" % \
            (repr(self.zeroes),
             repr(self.m),
             repr(self.offset),
             repr(self.add),
             repr(self.sign))

    @classmethod
    def encode(cls, residual, entropy):
        """given a residual integer and list of three entropies
        returns a Residual object and updates the entropies

        entropy is modified in place: for buckets 0 to 2 the medians
        below the chosen bucket are increased and the chosen bucket's
        median is decreased; for larger buckets all three medians are
        increased.  The returned object's zeroes field is None."""

        # figure out unsigned from signed
        if residual >= 0:
            unsigned = residual
            sign = 0
        else:
            unsigned = -residual - 1
            sign = 1

        medians = [e // 2 ** 4 + 1 for e in entropy]

        # figure out m, offset, add and update channel's entropies
        if unsigned < medians[0]:
            m = 0
            offset = unsigned
            add = medians[0] - 1
            entropy[0] -= ((entropy[0] + 126) // 128) * 2
        elif (unsigned - medians[0]) < medians[1]:
            m = 1
            offset = unsigned - medians[0]
            add = medians[1] - 1
            entropy[0] += ((entropy[0] + 128) // 128) * 5
            entropy[1] -= ((entropy[1] + 62) // 64) * 2
        elif (unsigned - (medians[0] + medians[1])) < medians[2]:
            m = 2
            offset = unsigned - (medians[0] + medians[1])
            add = medians[2] - 1
            entropy[0] += ((entropy[0] + 128) // 128) * 5
            entropy[1] += ((entropy[1] + 64) // 64) * 5
            entropy[2] -= ((entropy[2] + 30) // 32) * 2
        else:
            m = (((unsigned - (medians[0] + medians[1])) // medians[2]) + 2)
            offset = (unsigned -
                      (medians[0] + medians[1] + ((m - 2) * medians[2])))
            add = medians[2] - 1
            entropy[0] += ((entropy[0] + 128) // 128) * 5
            entropy[1] += ((entropy[1] + 64) // 64) * 5
            entropy[2] += ((entropy[2] + 32) // 32) * 5

        # zeroes will be populated later
        return cls(zeroes=None, m=m, offset=offset, add=add, sign=sign)

    def flush(self, writer, u_i_2, m_i):
        """given a BitstreamWriter, u_{i - 2} and m_{i},
        encodes residual_{i - 1}'s values to disk

        returns u_{i - 1}, or None when it is undefined
        (including when this residual has no m value)

        raises ValueError if m / m_i are not non-negative"""

        if self.zeroes is not None:
            write_egc(writer, self.zeroes)

        if self.m is None:
            # zero-run-only or placeholder residual: no unary value
            return None

        # calculate unary_{i - 1} based on m_{i}
        if (self.m > 0) and (m_i > 0):
            # positive m to positive m
            if (u_i_2 is None) or (u_i_2 % 2 == 0):
                u_i_1 = (self.m * 2) + 1
            else:
                # passing 1 from previous u
                u_i_1 = (self.m * 2) - 1
        elif (self.m == 0) and (m_i > 0):
            # zero m to positive m
            if (u_i_2 is None) or (u_i_2 % 2 == 1):
                u_i_1 = 1
            else:
                # passing 0 from previous u
                u_i_1 = None
        elif (self.m > 0) and (m_i == 0):
            # positive m to zero m
            if (u_i_2 is None) or (u_i_2 % 2 == 0):
                u_i_1 = self.m * 2
            else:
                # passing 1 from previous u
                u_i_1 = (self.m - 1) * 2
        elif (self.m == 0) and (m_i == 0):
            # zero m to zero m
            if (u_i_2 is None) or (u_i_2 % 2 == 1):
                u_i_1 = 0
            else:
                # passing 0 from previous u
                u_i_1 = None
        else:
            raise ValueError("invalid m")

        # write residual_{i - 1} to disk based on unary_{i - 1}
        if u_i_1 is not None:
            if u_i_1 < 16:
                writer.unary(0, u_i_1)
            else:
                # 16 acts as an escape; the excess follows via write_egc
                writer.unary(0, 16)
                write_egc(writer, u_i_1 - 16)

        if self.add > 0:
            # p = floor(log2(add)), computed exactly on integers
            # rather than with a float log() that can misround
            p = self.add.bit_length() - 1
            e = 2 ** (p + 1) - self.add - 1
            if self.offset < e:
                writer.write(p, self.offset)
            else:
                writer.write(p, (self.offset + e) // 2)
                writer.write(1, (self.offset + e) % 2)

        writer.write(1, self.sign)

        return u_i_1
1325
1326
def write_bitstream(writer, channels, entropies):
    """Entropy-code a block of residuals to writer.

    channels[c][s] is residual 's' of channel 'c'; residuals are visited
    interleaved by channel (s0c0, s0c1, s1c0, ...) and every channel is
    assumed to be as long as channels[0].
    entropies[c] is channel c's list of three entropy values, updated in
    place by Residual.encode() and reset to zero whenever a run of zeroes
    begins.  Both entropies[0] and entropies[1] are read, so two lists
    must be supplied even for a single channel.

    Residuals are written one step late: residual i-1 is flushed only
    once residual i's m value is known, because Residual.flush() needs
    m_{i} to decide how u_{i-1} is written.
    """
    # residual_{-1}
    r_i_1 = Residual(zeroes=None, m=None, offset=None, add=None, sign=None)

    # u_{-2}
    u_i_2 = None

    i = 0

    while i < (len(channels) * len(channels[0])):
        r = channels[i % len(channels)][i // len(channels)]

        # zero-run handling is only considered while both channels'
        # first entropy value is below 2 and the pending residual's
        # unary value would be undefined
        if (((entropies[0][0] < 2) and (entropies[1][0] < 2) and
             unary_undefined(u_i_2, r_i_1.m))):
            if (r_i_1.zeroes is not None) and (r_i_1.m is None):
                # in a block of zeroes
                if r == 0:
                    # continue block of zeroes
                    r_i_1.zeroes += 1
                else:
                    # end block of zeroes
                    # (no flush here: the zero count moves onto r_i and is
                    # written first when r_i itself is flushed)
                    r_i = Residual.encode(r, entropies[i % len(channels)])
                    r_i.zeroes = r_i_1.zeroes
                    r_i_1 = r_i
            else:
                # start a new block of zeroes
                if r == 0:
                    r_i = Residual(zeroes=1,
                                   m=None, offset=None, add=None, sign=None)
                    u_i_2 = r_i_1.flush(writer, u_i_2, 0)
                    entropies[0][0] = entropies[0][1] = entropies[0][2] = 0
                    entropies[1][0] = entropies[1][1] = entropies[1][2] = 0
                    r_i_1 = r_i
                else:
                    # false alarm block of zeroes
                    # (a zero-length run is still recorded on r_i)
                    r_i = Residual.encode(r, entropies[i % len(channels)])
                    r_i.zeroes = 0
                    u_i_2 = r_i_1.flush(writer, u_i_2, r_i.m)
                    r_i_1 = r_i
        else:
            # encode regular residual
            r_i = Residual.encode(r, entropies[i % len(channels)])
            r_i.zeroes = None
            u_i_2 = r_i_1.flush(writer, u_i_2, r_i.m)
            r_i_1 = r_i

        i += 1

    # flush final residual
    u_i_2 = r_i_1.flush(writer, u_i_2, 0)
1377
1378
def unary_undefined(prev_u, m):
    """Report whether the next unary value is undefined.

    Given u_{i - 1} (prev_u, possibly None) and m_{i} (possibly None),
    returns True if u_{i} is undefined: when m is None, or when m is 0
    and prev_u is a known even number.  Returns False otherwise.
    """
    if m is None:
        return True
    return (m == 0) and (prev_u is not None) and (prev_u % 2 == 0)
1390
1391
def write_egc(writer, value):
    """Write a non-negative int as an Elias-gamma-style code.

    Values 0 and 1 are written as unary(0, value).  Larger values write
    their bit length t in unary, followed by the low t - 1 bits of the
    value (the leading 1 bit is implied).

    value must be >= 0 (checked with assert).
    """
    assert(value >= 0)

    if value > 1:
        # exact bit length; float log(value) / log(2) is off by one
        # for large values such as 2 ** 60 - 1
        t = value.bit_length()
        writer.unary(0, t)
        writer.write(t - 1, value % (2 ** (t - 1)))
    else:
        writer.unary(0, value)
1403
1404
def write_residual(writer, u, offset, add, sign):
    """Placeholder for writing a single residual; currently does nothing.

    Intended inputs are u_{i}, offset_{i}, add_{i} and sign_{i}, where
    u_{i} may be None, indicating an undefined unary value.  The writer
    is never touched and None is always returned; in this module the
    residual bits are actually written by Residual.flush().
    """
    return None
1409
1410
def write_extended_integers(writer,
                            sent_bits, zero_bits, one_bits, duplicate_bits):
    """Write the four extended-integer parameters as unsigned 8-bit fields."""
    fields = (sent_bits, zero_bits, one_bits, duplicate_bits)
    writer.build("8u 8u 8u 8u", fields)
1415
1416
if (__name__ == '__main__'):
    # smoke test: entropy-code a short mono residual sequence that
    # contains a run of zeroes.  write_bitstream() calls
    # writer.unary() / writer.write(), so a real writer is required;
    # passing None failed as soon as any residual data was written.
    # NOTE(review): BitstreamRecorder(1) is assumed to select
    # little-endian output - confirm against audiotools.bitstream
    recorder = BitstreamRecorder(1)
    write_bitstream(recorder,
                    [[1, 2, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                      0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -2, -3, -2, -1]],
                    [[0, 0, 0], [0, 0, 0]])
    recorder.byte_align()
    print("%d bytes written" % (recorder.bytes(),))
1422