1from __future__ import division
2
3__all__ = [ 'ArrayTree', 'FileArrayTreeDict', 'array_tree_dict_from_reader' ]
4
5import numpy
6from numpy import *
7cimport numpy
8
9cimport bx.arrays.wiggle
10
11from bx.misc.binary_file import BinaryFileWriter, BinaryFileReader
12from bx.misc.cdb import FileCDBDict
13
14"""
15Classes for storing binary data on disk in a tree structure that allows for
16efficient sparse storage (when the data occurs in contiguous blocks), fast
17access to a specific block of data, and fast access to summaries at different
18resolutions.
19
20On disk format
21--------------
22
23Blocks are stored contiguously on disk in level-order. Contents should always be
24network byte order (big endian), however this implementation will byte-swap when
reading if necessary. File contents:
26
27- magic:      uint32
- version:    uint32
29- array size: uint32
30- block size: uint32
31- array type: 4 chars (numpy typecode, currently only simple types represented by one char are supported)
32
33- Internal nodes in level order
34    - Summary
        - count of valid values in each subtree : sizeof( dtype ) * block_size
        - frequencies: sizeof( dtype ) * block_size
        - sum of valid values in each subtree : sizeof( dtype ) * block_size
        - min of valid values in each subtree : sizeof( dtype ) * block_size
        - max of valid values in each subtree : sizeof( dtype ) * block_size
        - sum of squares of valid values in each subtree : sizeof( dtype ) * block_size
41    - File offsets of each child node: uint64 * block_size
42
43- Leaf nodes
44    - data points: sizeof( dtype ) * block_size
45
46- Version 1 reads version 0 and version 1
47
48"""
49
50## Enhancement ideas:
51##
52##   - Write markers of the number of blocks skipped between blocks. This would
53##     allow fast striding across data or summaries (use the indexes to get to
54##     the start of a block, then just read straight through). Would this help?
55##
56##   - Compression for blocks?
57
# Magic number written at the start of an array tree dict file and passed to
# BinaryFileReader when such a file is opened.
MAGIC = 0x310ec7dc
# Format version written to new files (readers accept 0 or 1).
VERSION = 1
# Number of summary arrays stored per internal node (counts, frequencies,
# sums, mins, maxs, sumsquares); used to compute how many bytes to skip to
# reach a node's table of child offsets.
NUM_SUMMARY_ARRAYS = 6
61
def array_tree_dict_from_reader( reader, sizes, default_size=2147483647, block_size=1000, no_leaves=False ):
    """
    Build a dictionary mapping chromosome name to `ArrayTree` from `reader`.

    `reader` must yield 5-tuples `( chrom, start, end, _, val )`; the fourth
    field is ignored. For each tuple, `val` is stored at every position of the
    half-open range `[start, end)` in the tree belonging to `chrom`.

    Trees are created lazily the first time a chromosome is seen. The size of
    a tree is looked up in `sizes` (a mapping of chromosome name to length),
    falling back to `default_size`. `block_size` and `no_leaves` are passed
    through to the `ArrayTree` constructor.

    Returns the dictionary of trees (empty if `reader` yields nothing).
    """
    rval = {}
    # Remember the tree used for the previous row so that runs of rows on the
    # same chromosome do not repeat the dictionary lookup
    last_chrom = None
    last_array_tree = None
    for chrom, start, end, _, val in reader:
        if chrom != last_chrom:
            if chrom not in rval:
                rval[chrom] = ArrayTree( sizes.get( chrom, default_size ), block_size, no_leaves=no_leaves )
            last_array_tree = rval[chrom]
            last_chrom = chrom
        last_array_tree.set_range( start, end, val )
    return rval
77
78
cdef class FileArrayTreeDict:
    """
    Access to a file containing multiple array trees indexed by a string key.

    The file starts with the magic number and a version, followed by a cdb
    index mapping each key to the packed file offset of its tree, followed by
    the trees themselves.
    """
    # Reader wrapped around the underlying file
    cdef object io
    # On-disk index mapping key -> packed offset of the tree
    cdef object cdb_dict

    def __init__( self, file ):
        """
        Open `file`, check the version field, and load the key index.
        """
        self.io = io = BinaryFileReader( file, MAGIC )
        assert (0 <= io.read_uint32() <= 1) # Check for version 0 or 1
        self.cdb_dict = FileCDBDict( file, is_little_endian=io.is_little_endian )

    def __getitem__( self, key ):
        """
        Return a `FileArrayTree` for the tree stored under `key`.

        The file is positioned at the start of that tree before the wrapper
        is created.
        """
        offset = self.cdb_dict[key]
        offset = self.io.unpack( "L", offset.encode() )[0]
        self.io.seek( offset )
        return FileArrayTree( self.io.file, self.io.is_little_endian )

    @classmethod
    def dict_to_file( Class, dict, file, is_little_endian=True, no_leaves=False ):
        """
        Writes a dictionary of array trees to a file that can then be
        read efficiently using this class.

        `dict` maps keys to objects providing `to_file` (array trees); `file`
        must support seeking because the index is written twice: once with
        placeholder offsets to reserve space and once with the real offsets.
        `is_little_endian` and `no_leaves` are passed on to each tree's
        `to_file`.
        """
        io = BinaryFileWriter( file, is_little_endian=is_little_endian )
        # Write magic number and version
        io.write_uint32( MAGIC )
        io.write_uint32( VERSION )
        # Write cdb index with fake values just to fill space. Plain iteration
        # (rather than the Python 2 only `iterkeys`) works for any mapping.
        cdb_dict = {}
        for key in dict:
            cdb_dict[ key ] = io.pack( "L", 0 )
        cdb_offset = io.tell()
        FileCDBDict.to_file( cdb_dict, file, is_little_endian=is_little_endian )
        # Write each tree and save offset
        for key, value in dict.items():
            offset = io.tell()
            cdb_dict[ key ] = io.pack( "L", offset )
            value.to_file( file, is_little_endian=is_little_endian, no_leaves=no_leaves )
        # Go back and write the index again, now with the real offsets
        io.seek( cdb_offset )
        FileCDBDict.to_file( cdb_dict, file, is_little_endian=is_little_endian )
119
cdef class FileArrayTree:
    """
    Wrapper for ArrayTree stored in file that reads as little as possible.

    Only the small tree header is read on construction; individual values,
    leaves and summaries are located by following child offsets from the root
    each time they are requested.
    """
    cdef public int max
    cdef public int block_size
    cdef public object dtype
    cdef public int levels
    # File positions are kept as 64 bit integers (child offsets are stored as
    # uint64 and `r_seek_to_node` takes a `long long`), so trees located past
    # the 2 GiB mark do not overflow.
    cdef public long long offset
    cdef public long long root_offset
    cdef object io

    def __init__( self, file, is_little_endian=True ):
        """
        Read the tree header from the current position of `file`.
        """
        self.io = BinaryFileReader( file, is_little_endian=is_little_endian )
        self.offset = self.io.tell()
        # Read basic info about the tree
        self.max = self.io.read_uint32()
        self.block_size = self.io.read_uint32()
        # Read dtype and canonicalize
        dt = self.io.read( 1 )
        self.dtype = numpy.dtype( dt )
        # The type code is stored in a 4 byte field, only the first is used
        self.io.skip( 3 )
        # How many levels are needed to cover the entire range?
        self.levels = 0
        while ( <long long> self.block_size ) ** ( self.levels + 1 ) < self.max:
            self.levels += 1
        # Not yet dealing with the case where the root is a Leaf
        assert self.levels > 0, "max < block_size not yet handled"
        # Save offset of root
        self.root_offset = self.io.tell()

    def __getitem__( self, index ):
        """
        Return the value stored at `index`, or `nan` if the leaf that would
        contain it is not present in the file.
        """
        leaf_min = self.r_seek_to_node( index, 0, self.root_offset, self.levels, 0 )
        if leaf_min < 0:
            return nan
        # Now at the start of the leaf, step forward to the requested element
        self.io.skip( self.dtype.itemsize * ( index - leaf_min ) )
        return self.io.read_raw_array( self.dtype, 1 )[0]

    def get_summary( self, index, level ):
        """
        Return the `Summary` of the node at `level` that contains `index`, or
        None if that node is not present in the file.

        Raises ValueError unless `0 < level <= self.levels`.
        """
        if level <= 0 or level > self.levels:
            raise ValueError( "level must be <= self.levels" )
        if self.r_seek_to_node( index, 0, self.root_offset, self.levels, level ) < 0:
            return None
        # Read summary arrays, in the order they are stored
        s = Summary()
        s.counts = self.io.read_raw_array( self.dtype, self.block_size )
        s.frequencies = self.io.read_raw_array( self.dtype, self.block_size )
        s.sums = self.io.read_raw_array( self.dtype, self.block_size )
        s.mins = self.io.read_raw_array( self.dtype, self.block_size)
        s.maxs = self.io.read_raw_array( self.dtype, self.block_size )
        s.sumsquares = self.io.read_raw_array( self.dtype, self.block_size )
        return s

    def get_leaf( self, index ):
        """
        Return the whole block of values of the leaf containing `index`, or an
        empty list if that leaf is not present in the file.
        """
        if self.r_seek_to_node( index, 0, self.root_offset, self.levels, 0 ) < 0:
            return []
        return self.io.read_raw_array( self.dtype, self.block_size )

    cdef int r_seek_to_node( self, int index, int min, long long offset, int level, int desired_level ):
        """
        Seek to the start of the node at `desired_level` that contains `index`.

        `min` is the first index covered by the node stored at `offset`, and
        `level` is the level of that node. Returns the first index covered by
        the node that was reached, or -1 if a child on the way down has a
        stored offset of 0 (the subtree is absent).
        """
        cdef int child_size, bin_index, child_min
        self.io.seek( offset )
        if level > desired_level:
            child_size = self.block_size ** level
            bin_index = ( index - min ) // ( child_size )
            child_min = min + ( bin_index * child_size )
            # Skip summary arrays -- # arrays * itemsize * block_size
            self.io.skip( NUM_SUMMARY_ARRAYS * self.dtype.itemsize * self.block_size )
            # Skip to offset of correct child -- offsets are 8 bytes
            self.io.skip( 8 * bin_index )
            # Read offset of child
            child_offset = self.io.read_uint64()
            if child_offset == 0:
                return -1
            return self.r_seek_to_node( index, child_min, child_offset, level - 1, desired_level )
        else:
            # The file pointer is at the start of the desired node, do nothing
            return min
202
cdef class Summary:
    """
    Summary for a non-leaf level of the tree. Contains one array per
    statistic -- valid count, frequency, min, max, sum, and sum-of-squares --
    each holding one entry per child.
    """
    # Number of non-NaN values under each child
    cdef public object counts
    # Number of `set` calls recorded under each child
    cdef public object frequencies
    # Smallest non-NaN value under each child
    cdef public object mins
    # Largest non-NaN value under each child
    cdef public object maxs
    # Sum of the non-NaN values under each child
    cdef public object sums
    # Sum of squares of the non-NaN values under each child
    cdef public object sumsquares
214
# Forward declarations: ArrayTree refers to ArrayTreeNode (and nodes create
# ArrayTreeLeaf instances) before these classes are defined further down.
cdef class ArrayTreeNode
cdef class ArrayTreeLeaf
217
cdef class ArrayTree:
    """
    Stores a sparse array of data as a tree.

    An array of `self.max` values is stored in a tree in which each leaf
    contains `self.block_size` values and each internal node contains
    `self.block_size` children.

    Entirely empty subtrees are not stored. Thus, the storage is efficient for
    data that is block sparse -- having contiguous chunks of `self.block_size` or
    larger data. Currently it is not efficient if the data is strided (e.g.
    one or two data points in every interval of length `self.block_size`).

    Internal nodes store `Summary` instances for their subtrees.
    """

    cdef public int max
    cdef public int block_size
    cdef public object dtype
    cdef public int levels
    cdef public int no_leaves
    cdef public ArrayTreeNode root

    def __init__( self, int max, int block_size, dtype=float32, no_leaves=False ):
        """
        Create a new array tree of size `max`, with `block_size` values per
        leaf and `block_size` children per internal node. `dtype` is anything
        accepted by `numpy.dtype`.
        """
        self.max = max
        self.block_size = block_size
        self.no_leaves = no_leaves
        # Force the dtype argument to its canonical dtype object
        self.dtype = numpy.dtype( dtype )
        # How many levels are needed to cover the entire range?
        self.levels = 0
        while ( <long long> self.block_size ) ** ( self.levels + 1 ) < self.max:
            self.levels += 1
        # Not yet dealing with the case where the root is a Leaf
        assert self.levels > 0, "max < block_size not yet handled"
        # Create the root node
        self.root = ArrayTreeNode( self, 0, max, block_size, self.levels )

    def __setitem__( self, int index, value ):
        """
        Store `value` at position `index`.
        """
        self.root.set( index, value )

    def set_range( self, int start, int end, value ):
        """
        Store `value` at every position in the half-open range `[start, end)`.
        """
        cdef int i
        for i in range( start, end ):
            self.root.set( i, value )

    def __getitem__( self, int index ):
        """
        Return the value at `index` (`nan` if the enclosing block is empty).
        """
        return self.root.get( index )

    def to_file( self, f, is_little_endian=True, no_leaves=False ):
        """
        Write the tree to `f` starting at the current file position.

        Summaries must already have been built (`self.root.build_summary()`).
        When `no_leaves` is true only internal nodes are written. `f` must be
        seekable since child offsets are filled in by a second pass. On return
        the file is positioned just past the last byte written for this tree,
        so further data can be appended safely.
        """
        io = BinaryFileWriter( f, is_little_endian=is_little_endian )
        ## io.write_uint32( VERSION )
        io.write_uint32( self.max )
        io.write_uint32( self.block_size )
        io.write( self.dtype.char )
        io.write( "\0\0\0" )
        # Data pass, level order
        if no_leaves:
            bottom_level = 0
        else:
            bottom_level = -1
        for level in range( self.levels, bottom_level, -1 ):
            self.root.to_file_data_pass( io, level )
        # The offset pass seeks back into the internal nodes, remember where
        # this tree's data ends so the file position can be restored afterwards
        end_offset = io.tell()
        # Offset pass to fix up indexes
        self.root.to_file_offset_pass( io )
        # Without this a following write would start in the middle of this
        # tree and overwrite the leaf blocks
        io.seek( end_offset )

    @classmethod
    def from_file( Class, f, is_little_endian=True ):
        """
        Read a complete tree (header, summaries and all stored children) from
        the current position of `f` into memory.
        """
        io = BinaryFileReader( f, is_little_endian=is_little_endian )
        ## assert io.read_uint32() == VERSION
        size = io.read_uint32()
        block_size = io.read_uint32()
        dt = io.read( 1 )
        # Remaining three bytes of the type code field are padding
        io.read( 3 )
        tree = Class( size, block_size, dt )
        tree.root.from_file( io )
        return tree

    @classmethod
    def from_sequence( Class, s, block_size=1000 ):
        """
        Build an ArrayTree from a sequence like object (must have at least
        length and getitem).
        """
        tree = Class( len( s ), block_size )
        for i in range( len( s ) ):
            tree[i] = s[i]
        return tree
308
cdef class ArrayTreeNode:
    """
    Internal node of an ArrayTree. Contains summary data and pointers to
    subtrees.

    A node at level `level` has `block_size` child slots, each covering
    `block_size ** level` consecutive positions; children of a level 1 node
    are leaves. Slots for empty subtrees stay None.
    """

    cdef ArrayTree tree
    cdef int min
    cdef int max
    cdef int block_size
    cdef int level
    cdef int child_size
    cdef object children
    cdef public Summary summary
    # File position of this node, set while writing; 64 bit to match the
    # uint64 child offsets stored on disk
    cdef public long long start_offset

    def __init__( self, ArrayTree tree, int min, int max, int block_size, int level ):
        """
        Create an empty node of `tree` at `level` whose first position is `min`.
        """
        self.tree = tree
        self.min = min
        self.max = max
        self.block_size = block_size
        self.level = level
        # Each of my children represents block_size ** level values
        self.child_size = self.block_size ** self.level
        self.children = [None] * self.block_size
        self.summary = None
        self.start_offset = 0

    cdef inline init_bin( self, int index ):
        """
        Create the (empty) child for slot `index`: a leaf at level 1,
        otherwise another internal node one level down.
        """
        cdef int min = self.min + ( index * self.child_size )
        cdef int max = min + self.child_size
        if self.level == 1:
            self.children[ index ] = ArrayTreeLeaf( self.tree, min, max )
        else:
            self.children[ index ] = ArrayTreeNode( self.tree, min, max, self.block_size, self.level - 1 )

    def set( self, int index, value ):
        """
        Store `value` at position `index`, creating the child on demand.
        """
        cdef int bin_index = ( index - self.min ) // ( self.child_size )
        if self.children[ bin_index ] is None:
            self.init_bin( bin_index )
        self.children[ bin_index ].set( index, value )

    def get( self, int index ):
        """
        Return the value at position `index`, or `nan` if the child that
        would contain it does not exist.
        """
        cdef int bin_index = ( index - self.min ) // ( self.child_size )
        if self.children[ bin_index ] is None:
            return nan
        else:
            return self.children[ bin_index ].get( index )

    cpdef build_summary( self ):
        """
        Build summary of children.

        Recursively builds the summaries of all internal nodes below this one
        and stores the result in `self.summary`. Slots without a child get a
        count and frequency of 0 and `nan` for the other statistics.
        """
        counts = empty( self.tree.block_size, self.tree.dtype )
        frequencies = empty( self.tree.block_size, self.tree.dtype )
        mins = empty( self.tree.block_size, self.tree.dtype )
        maxs = empty( self.tree.block_size, self.tree.dtype )
        sums = empty( self.tree.block_size, self.tree.dtype )
        sumsquares = empty( self.tree.block_size, self.tree.dtype )
        for i, child in enumerate( self.children ):
            if child is None:
                counts[i] = 0
                frequencies[i] = 0
                mins[i] = nan
                maxs[i] = nan
                sums[i] = nan
                sumsquares[i] = nan
            elif self.level == 1:
                # Child is a leaf, summarize its values directly
                v = child.values
                counts[i] = sum( ~isnan( v ) )
                frequencies[i] = child.frequency
                mins[i] = nanmin( v )
                maxs[i] = nanmax( v )
                sums[i] = nansum( v )
                sumsquares[i] = nansum( v ** 2 )
            else:
                # Child is an internal node, combine its own summary
                child.build_summary()
                counts[i] = sum( child.summary.counts )
                frequencies[i] = sum( child.summary.frequencies )
                mins[i] = nanmin( child.summary.mins )
                maxs[i] = nanmax( child.summary.maxs )
                sums[i] = nansum( child.summary.sums )
                sumsquares[i] = nansum( child.summary.sumsquares )
        s = Summary()
        s.counts = counts
        s.frequencies = frequencies
        s.mins = mins
        s.maxs = maxs
        s.sums = sums
        s.sumsquares = sumsquares
        self.summary = s

    def to_file_data_pass( self, io, level ):
        """
        First pass of writing to file, writes data and saves position of block.

        Only nodes whose level equals `level` are written; for other levels
        the call is forwarded to all existing children.
        """
        assert self.summary, "Writing without summaries is currently not supported"
        # If we are at the current level being written, write a block
        if self.level == level:
            # Save file offset where this block starts
            self.start_offset = io.tell()
            # Write out summary data
            io.write_raw_array( self.summary.counts )
            io.write_raw_array( self.summary.frequencies )
            io.write_raw_array( self.summary.sums )
            io.write_raw_array( self.summary.mins )
            io.write_raw_array( self.summary.maxs )
            io.write_raw_array( self.summary.sumsquares )
            # Skip enough room for child offsets (block_size children * 64bits)
            io.skip( self.tree.block_size * 8 )
        # Must be writing a lower level, so recurse
        else:
            # Write all non-empty children
            for child in self.children:
                if child is not None:
                    child.to_file_data_pass( io, level )

    def to_file_offset_pass( self, io ):
        """
        Second pass of writing to file, seek to appropriate position and write
        offsets of children.

        Missing children are recorded with an offset of 0.
        """
        # Seek to location of child offsets (skip over # summary arrays)
        skip_amount = NUM_SUMMARY_ARRAYS * self.tree.dtype.itemsize * self.block_size
        io.seek( self.start_offset + skip_amount )
        # Write the file offset of each child into the index
        for child in self.children:
            if child is None:
                io.write_uint64( 0 )
            else:
                io.write_uint64( child.start_offset )
        # Recursively write offsets in child nodes
        for child in self.children:
            if child is not None:
                child.to_file_offset_pass( io )

    def from_file( self, io ):
        """
        Load entire summary and all children into memory.

        `io` must be positioned at the start of this node's block.
        """
        dtype = self.tree.dtype
        block_size = self.tree.block_size
        # Read summary arrays. Every array, including the frequencies, is
        # written with the tree's dtype (see `build_summary`), so it has to be
        # read back with that dtype to get the right values and stay aligned.
        s = Summary()
        s.counts = io.read_raw_array( dtype, block_size )
        s.frequencies = io.read_raw_array( dtype, block_size )
        s.sums = io.read_raw_array( dtype, block_size )
        s.mins = io.read_raw_array( dtype, block_size)
        s.maxs = io.read_raw_array( dtype, block_size )
        s.sumsquares = io.read_raw_array( dtype, block_size )
        self.summary = s
        # Read offset of all children
        child_offsets = [ io.read_uint64() for i in range( block_size ) ]
        for i in range( block_size ):
            # An offset of 0 marks a child that was not stored
            if child_offsets[i] > 0:
                self.init_bin( i )
                io.seek( child_offsets[i] )
                self.children[i].from_file( io )

    def get_from_file( self, io, index ):
        """
        Return the value at `index` from the children held in memory (`nan`
        if the child is missing). `io` is currently not used.
        """
        cdef int bin_index = ( index - self.min ) //( self.child_size )
        if self.children[ bin_index ] is None:
            return nan
        else:
            return self.children[ bin_index ].get( index )
476
cdef class ArrayTreeLeaf:
    """
    Terminal node of an ArrayTree: holds the actual data values for the
    positions from `min` (inclusive) up to `max` (exclusive).
    """

    cdef ArrayTree tree
    cdef int min
    cdef int max
    cdef public int frequency
    cdef public numpy.ndarray values
    cdef public long start_offset

    def __init__( self, ArrayTree tree, int min, int max ):
        """
        Create a leaf of `tree` in which every position is still unset.
        """
        self.tree = tree
        self.min = min
        self.max = max
        self.frequency = 0
        # Unset positions are represented by NaN
        block = empty( max - min, self.tree.dtype )
        block[:] = nan
        self.values = block
        self.start_offset = 0

    def set( self, index, value ):
        """
        Record `value` for position `index` and count the assignment.
        """
        self.frequency += 1
        slot = index - self.min
        self.values[ slot ] = value

    def get( self, index ):
        """
        Return the value recorded for position `index`.
        """
        slot = index - self.min
        return self.values[ slot ]

    def to_file_data_pass( self, io, level ):
        """
        Write the block of values if the leaf level (0) is being written,
        remembering the file position at which it starts.
        """
        assert level == 0
        self.start_offset = io.tell()
        io.write_raw_array( self.values )

    def to_file_offset_pass( self, io ):
        """
        Leaves have no children, so there are no offsets to write.
        """
        pass

    def from_file( self, io ):
        """
        Replace the values with one block read from the current position of `io`.
        """
        tree = self.tree
        self.values = io.read_raw_array( tree.dtype, tree.block_size )
515