1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
"""
    Arrow.ArrowVector

Abstract supertype, itself a subtype of `AbstractVector`, shared by every specific
arrow array type. For the individual array types, see [`BoolVector`](@ref),
[`Primitive`](@ref), [`List`](@ref), [`Map`](@ref), [`FixedSizeList`](@ref),
[`Struct`](@ref), [`DenseUnion`](@ref), [`SparseUnion`](@ref), and
[`DictEncoded`](@ref).
"""
abstract type ArrowVector{T} <: AbstractVector{T} end

# every arrow vector type is linearly indexed
Base.IndexStyle(::Type{<:ArrowVector}) = Base.IndexLinear()

# `similar` on an arrow vector *type* allocates a plain `Vector` with the same eltype
Base.similar(::Type{<:ArrowVector{T}}, dims::Dims) where {T} = Vector{T}(undef, dims)

# generic accessors: read the `validity` / `metadata` fields of the array
function validitybitmap(x::ArrowVector)
    return x.validity
end

# null count as recorded on the array's validity bitmap
function nullcount(x::ArrowVector)
    bitmap = validitybitmap(x)
    return bitmap.nc
end

function getmetadata(x::ArrowVector)
    return x.metadata
end
32
# Convert a top-level column `x` to arrow format via `arrowvector` (nesting arguments
# fixed to 0, 0) and, when a `compression` codec is given, wrap the result with `compress`.
# `i`, `de`, `ded` and `meta` are forwarded unchanged to `arrowvector`.
function toarrowvector(x, i=1, de=Dict{Int64, Any}(), ded=DictEncoding[], meta=getmetadata(x); compression::Union{Nothing, LZ4FrameCompressor, ZstdCompressor}=nothing, kw...)
    @debug 2 "converting top-level column to arrow format: col = $(typeof(x)), compression = $compression, kw = $(kw.data)"
    @debug 3 x
    converted = arrowvector(x, i, 0, 0, de, ded, meta; compression=compression, kw...)
    # `compression` is restricted by its annotation to nothing / LZ4 / Zstd
    A = if compression === nothing
        converted
    elseif compression isa LZ4FrameCompressor
        compress(Meta.CompressionType.LZ4_FRAME, compression, converted)
    else
        compress(Meta.CompressionType.ZSTD, compression, converted)
    end
    @debug 2 "converted top-level column to arrow format: $(typeof(A))"
    @debug 3 A
    return A
end
46
# Entry point for converting a column: optionally wrap `x` in `DictEncode`, then
# dispatch on the column's element type (after `maybemissing`).
#
# `x` is wrapped when it is not already a `DictEncode`, `dictencoding` is false, and
# either `dictencode` was requested or `x` is an array whose `DataAPI.refarray` is a
# different object than `x` itself. `dictencoding` is consumed here; `dictencode` is
# passed on.
function arrowvector(x, i, nl, fi, de, ded, meta; dictencoding::Bool=false, dictencode::Bool=false, kw...)
    col = x
    if !(x isa DictEncode || dictencoding)
        if dictencode || (x isa AbstractArray && DataAPI.refarray(x) !== x)
            col = DictEncode(x, dictencodeid(i, nl, fi))
        end
    end
    return arrowvector(maybemissing(eltype(col)), col, i, nl, fi, de, ded, meta; dictencode=dictencode, kw...)
end
54
# defaults for Dates types
# NOTE(review): presumably the filler values ArrowTypes uses where no real value
# exists (e.g. missing slots) — confirm against ArrowTypes
ArrowTypes.default(::Type{Dates.Date}) = Dates.Date(1, 1, 1)
ArrowTypes.default(::Type{Dates.Time}) = Dates.Time(1, 1, 1)
ArrowTypes.default(::Type{Dates.DateTime}) = Dates.DateTime(1, 1, 1, 1, 1, 1)
ArrowTypes.default(::Type{TimeZones.ZonedDateTime}) =
    TimeZones.ZonedDateTime(1, 1, 1, 1, 1, 1, TimeZones.tz"UTC")
60
# conversions to arrow types
# each method passes the column through `converter` with the matching arrow type and
# then re-enters the generic `arrowvector` entry point with the converted column
function arrowvector(::Type{Dates.Date}, x, i, nl, fi, de, ded, meta; kw...)
    dates = converter(DATE, x)
    return arrowvector(dates, i, nl, fi, de, ded, meta; kw...)
end

function arrowvector(::Type{Dates.Time}, x, i, nl, fi, de, ded, meta; kw...)
    times = converter(TIME, x)
    return arrowvector(times, i, nl, fi, de, ded, meta; kw...)
end

function arrowvector(::Type{Dates.DateTime}, x, i, nl, fi, de, ded, meta; kw...)
    datetimes = converter(DATETIME, x)
    return arrowvector(datetimes, i, nl, fi, de, ded, meta; kw...)
end
# Time zone name (as a `Symbol`) taken from the first non-missing element of `x`.
# Falls back to `:UTC` when `x` is empty or holds no non-missing element, matching
# the UTC zone used by `ArrowTypes.default(ZonedDateTime)`.
function _firsttimezone(x)
    for y in x
        y === missing || return Symbol(y.timezone)
    end
    return :UTC
end

# Convert a `ZonedDateTime` column to a millisecond `Timestamp` column tagged with a
# single time zone. The zone comes from the first non-missing element rather than
# `x[1]`, so an empty column or a leading `missing` no longer throws.
# NOTE(review): only one element's zone is consulted; columns mixing zones rely on
# `converter` to handle the rest — confirm.
function arrowvector(::Type{ZonedDateTime}, x, i, nl, fi, de, ded, meta; kw...)
    tz = _firsttimezone(x)
    timestamps = converter(Timestamp{Meta.TimeUnit.MILLISECOND, tz}, x)
    return arrowvector(timestamps, i, nl, fi, de, ded, meta; kw...)
end
# `Dates.Period` columns are converted to a `Duration` whose parameter is
# `arrowperiodtype(P)`, then re-dispatched
function arrowvector(::Type{P}, x, i, nl, fi, de, ded, meta; kw...) where {P <: Dates.Period}
    durations = converter(Duration{arrowperiodtype(P)}, x)
    return arrowvector(durations, i, nl, fi, de, ded, meta; kw...)
end
72
# fallback that calls ArrowType
# For a type registered with ArrowTypes, the registration is recorded in `meta`
# (a fresh `Dict{String, String}` is created when `meta` is `nothing`). If the
# registered arrow type differs from `S`, the column is converted and re-dispatched;
# otherwise dispatch continues on `ArrowType(S)`.
function arrowvector(::Type{S}, x, i, nl, fi, de, ded, meta; kw...) where {S}
    if ArrowTypes.istyperegistered(S)
        if meta === nothing
            meta = Dict{String, String}()
        end
        registered = ArrowTypes.getarrowtype!(meta, S)
        if registered !== S
            return arrowvector(converter(registered, x), i, nl, fi, de, ded, meta; kw...)
        end
    end
    return arrowvector(ArrowType(S), x, i, nl, fi, de, ded, meta; kw...)
end
86
# a `NullType` column keeps nothing but its length
function arrowvector(::NullType, x, i, nl, fi, de, ded, meta; kw...)
    return MissingVector(length(x))
end

# "compressing" a `MissingVector` just wraps it, with empty `CompressedBuffer` and
# `Compressed` lists; the compressor `comp` is not used
function compress(Z::Meta.CompressionType, comp, v::MissingVector)
    len = length(v)
    return Compressed{Z, MissingVector}(v, CompressedBuffer[], len, len, Compressed[])
end
90
# Register the field node for a `MissingVector`: both `FieldNode` arguments are the
# column length. No buffer is added, so `bufferoffset` is returned unchanged.
function makenodesbuffers!(col::MissingVector, fieldnodes, fieldbuffers, bufferoffset, alignment)
    n = length(col)
    push!(fieldnodes, FieldNode(n, n))
    @debug 1 "made field node: nodeidx = $(length(fieldnodes)), col = $(typeof(col)), len = $(fieldnodes[end].length), nc = $(fieldnodes[end].null_count)"
    return bufferoffset
end
96
# a `MissingVector` has nothing to write: this is a no-op returning `nothing`
writebuffer(io, col::MissingVector, alignment) = nothing
100
"""
    Arrow.ValidityBitmap

A bit-packed array type where each bit corresponds to an element in an
[`ArrowVector`](@ref), indicating whether that element is "valid" (bit == 1),
or not (bit == 0). Used to indicate element missingness (whether it's null).

If the null count of an array is zero, the `ValidityBitmap` will be "empty"
and all elements are treated as "valid"/non-null.
"""
struct ValidityBitmap <: ArrowVector{Bool}
    bytes::Vector{UInt8} # arrow memory blob
    pos::Int # starting byte of validity bitmap
    ℓ::Int # # of _elements_ (not bytes!) in bitmap (because bitpacking)
    nc::Int # null count
end

# the bitmap's length is its element count `ℓ`, not the number of bytes backing it
Base.size(p::ValidityBitmap) = (p.ℓ,)
# the null count is stored directly on the bitmap
nullcount(x::ValidityBitmap) = x.nc
120
# Build the validity bitmap for the column `x` by scanning it once for `missing`.
#
# * If `eltype(x)` cannot hold `missing`, no scan is done: the result has no bytes,
#   `length(x)` elements and a null count of 0.
# * Otherwise one bit per element is packed, 8 per byte; a bit starts set and is
#   cleared (via `setbit`) for each `missing` element.
# * If the scan finds no `missing`, the packed bytes are dropped and an empty bitmap
#   (no bytes, 0 elements, null count 0) is returned.
function ValidityBitmap(x)
    T = eltype(x)
    if !(T >: Missing)
        return ValidityBitmap(UInt8[], 1, length(x), 0)
    end
    len = length(x)
    blen = cld(len, 8)
    bytes = Vector{UInt8}(undef, blen)
    nc = 0      # number of missing elements seen
    b = 0xff    # byte being assembled; all bits start as "valid"
    j = k = 1   # j: bit position within `b` (1-8); k: index of `b` in `bytes`
    for y in x
        if y === missing
            nc += 1
            b = setbit(b, false, j)
        end
        j += 1
        if j == 9
            # byte is full: store it and start the next one
            @inbounds bytes[k] = b
            b = 0xff
            j = 1
            k += 1
        end
    end
    # store the trailing, partially filled byte (its unused bits stay set)
    if j > 1
        bytes[k] = b
    end
    return ValidityBitmap(nc == 0 ? UInt8[] : bytes, 1, nc == 0 ? 0 : len, nc)
end
151
# Return whether element `i` is valid (non-null).
@propagate_inbounds function Base.getindex(p::ValidityBitmap, i::Integer)
    # bounds are not checked here; the parent array is expected to do it
    if p.nc != 0
        # map the element index to (byte, bit-within-byte) of the packed bitmap
        byteidx, bitidx = fldmod1(i, 8)
        @inbounds byte = p.bytes[p.pos + byteidx - 1]
        return getbit(byte, bitidx)
    end
    # a null count of 0 means either the parent array has no nulls or it is empty;
    # in both cases every element counts as valid and the bytes are not consulted
    return true
end
164
# Set the validity bit of element `i` to `v` (converted to `Bool`) and return `v`.
#
# A bitmap with `ℓ == 0` has no bits to modify:
#   * storing `true` is a no-op (previously this fell through to `@inbounds` reads
#     and writes outside the bitmap);
#   * storing `false` cannot be represented and throws `BoundsError`.
#
# Neither the stored null count `nc` nor `ℓ` is updated by this method.
@propagate_inbounds function Base.setindex!(p::ValidityBitmap, v, i::Integer)
    x = convert(Bool, v)
    if p.ℓ == 0
        x || throw(BoundsError(p, i))
        return v
    end
    # translate element index to (byte, bit-within-byte) of the packed bitmap
    a, b = fldmod1(i, 8)
    idx = p.pos + a - 1
    # guard the backing bytes: a bitmap may have `ℓ > 0` but no bytes behind it;
    # elided when the caller uses `@inbounds`
    @boundscheck checkbounds(p.bytes, idx)
    @inbounds byte = p.bytes[idx]
    @inbounds p.bytes[idx] = setbit(byte, x, b)
    return v
end
173
# Write the validity bitmap of `col` to `io` and return the number of bytes written,
# including the zero padding added after it (per `paddinglength(n, alignment)`).
# Nothing is written, and 0 is returned, when the null count is 0.
function writebitmap(io, col::ArrowVector, alignment)
    v = col.validity
    @debug 1 "writing validity bitmap: nc = $(v.nc), n = $(cld(v.ℓ, 8))"
    if v.nc == 0
        return 0
    end
    # the bitmap occupies cld(ℓ, 8) bytes of the blob, starting at `pos`
    lastbyte = v.pos + cld(v.ℓ, 8) - 1
    written = Base.write(io, view(v.bytes, v.pos:lastbyte))
    padding = writezeros(io, paddinglength(written, alignment))
    return written + padding
end
181
182include("compressed.jl")
183include("primitive.jl")
184include("bool.jl")
185include("list.jl")
186include("fixedsizelist.jl")
187include("map.jl")
188include("struct.jl")
189include("unions.jl")
190include("dictencoding.jl")
191