1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""
18    Arrow.DictEncoding
19
20Represents the "pool" of possible values for a [`DictEncoded`](@ref)
21array type. Whether the order of values is significant can be checked
22by looking at the `isOrdered` boolean field.
23"""
24mutable struct DictEncoding{T, A} <: ArrowVector{T}
25    id::Int64
26    data::A
27    isOrdered::Bool
28    metadata::Union{Nothing, Dict{String, String}}
29end
30
31Base.size(d::DictEncoding) = size(d.data)
32
33@propagate_inbounds function Base.getindex(d::DictEncoding{T}, i::Integer) where {T}
34    @boundscheck checkbounds(d, i)
35    return @inbounds ArrowTypes.arrowconvert(T, d.data[i])
36end
37
38# convenience wrapper to signal that an input column should be
39# dict encoded when written to the arrow format
40struct DictEncodeType{T} end
41getT(::Type{DictEncodeType{T}}) where {T} = T
42
43"""
44    Arrow.DictEncode(::AbstractVector, id::Integer=nothing)
45
46Signals that a column/array should be dictionary encoded when serialized
47to the arrow streaming/file format. An optional `id` number may be provided
48to signal that multiple columns should use the same pool when being
49dictionary encoded.
50"""
51struct DictEncode{T, A} <: AbstractVector{DictEncodeType{T}}
52    id::Int64
53    data::A
54end
55
56DictEncode(x::A, id=-1) where {A} = DictEncode{eltype(A), A}(id, x)
57Base.IndexStyle(::Type{<:DictEncode}) = Base.IndexLinear()
58Base.size(x::DictEncode) = (length(x.data),)
59Base.iterate(x::DictEncode, st...) = iterate(x.data, st...)
60Base.getindex(x::DictEncode, i::Int) = getindex(x.data, i)
61ArrowTypes.ArrowType(::Type{<:DictEncodeType}) = DictEncodedType()
62
63"""
64    Arrow.DictEncoded
65
66A dictionary encoded array type (similar to a `PooledArray`). Behaves just
67like a normal array in most respects; internally, possible values are stored
68in the `encoding::DictEncoding` field, while the `indices::Vector{<:Integer}`
69field holds the "codes" of each element for indexing into the encoding pool.
70Any column/array can be dict encoding when serializing to the arrow format
71either by passing the `dictencode=true` keyword argument to [`Arrow.write`](@ref)
72(which causes _all_ columns to be dict encoded), or wrapping individual columns/
73arrays in [`Arrow.DictEncode(x)`](@ref).
74"""
75struct DictEncoded{T, S, A} <: ArrowVector{T}
76    arrow::Vector{UInt8} # need to hold a reference to arrow memory blob
77    validity::ValidityBitmap
78    indices::Vector{S}
79    encoding::DictEncoding{T, A}
80    metadata::Union{Nothing, Dict{String, String}}
81end
82
83DictEncoded(b::Vector{UInt8}, v::ValidityBitmap, inds::Vector{S}, encoding::DictEncoding{T, A}, meta) where {S, T, A} =
84    DictEncoded{T, S, A}(b, v, inds, encoding, meta)
85
86Base.size(d::DictEncoded) = size(d.indices)
87
88isdictencoded(d::DictEncoded) = true
89isdictencoded(x) = false
90isdictencoded(c::Compressed{Z, A}) where {Z, A <: DictEncoded} = true
91
92signedtype(::Type{UInt8}) = Int8
93signedtype(::Type{UInt16}) = Int16
94signedtype(::Type{UInt32}) = Int32
95signedtype(::Type{UInt64}) = Int64
96
97indtype(d::DictEncoded{T, S, A}) where {T, S, A} = S
98indtype(c::Compressed{Z, A}) where {Z, A <: DictEncoded} = indtype(c.data)
99
100dictencodeid(colidx, nestedlevel, fieldid) = (Int64(nestedlevel) << 48) | (Int64(fieldid) << 32) | Int64(colidx)
101
102getid(d::DictEncoded) = d.encoding.id
103getid(c::Compressed{Z, A}) where {Z, A <: DictEncoded} = c.data.encoding.id
104
105arrowvector(::DictEncodedType, x::DictEncoded, i, nl, fi, de, ded, meta; kw...) = x
106
107function arrowvector(::DictEncodedType, x, i, nl, fi, de, ded, meta; dictencode::Bool=false, dictencodenested::Bool=false, kw...)
108    @assert x isa DictEncode
109    id = x.id == -1 ? dictencodeid(i, nl, fi) : x.id
110    x = x.data
111    len = length(x)
112    validity = ValidityBitmap(x)
113    if !haskey(de, id)
114        # dict encoding doesn't exist yet, so create for 1st time
115        if DataAPI.refarray(x) === x
116            # need to encode ourselves
117            x = PooledArray(x, encodingtype(length(x)))
118            inds = DataAPI.refarray(x)
119        else
120            inds = copy(DataAPI.refarray(x))
121        end
122        # adjust to "offset" instead of index
123        for i = 1:length(inds)
124            @inbounds inds[i] -= 1
125        end
126        pool = DataAPI.refpool(x)
127        # horrible hack? yes. better than taking CategoricalArrays dependency? also yes.
128        if typeof(pool).name.name == :CategoricalRefPool
129            pool = [get(pool[i]) for i = 1:length(pool)]
130        end
131        data = arrowvector(pool, i, nl, fi, de, ded, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
132        encoding = DictEncoding{eltype(data), typeof(data)}(id, data, false, getmetadata(data))
133        de[id] = Lockable(encoding)
134    else
135        # encoding already exists
136          # compute inds based on it
137          # if value doesn't exist in encoding, push! it
138          # also add to deltas updates
139        encodinglockable = de[id]
140        @lock encodinglockable begin
141            encoding = encodinglockable.x
142            len = length(x)
143            ET = encodingtype(len)
144            pool = Dict{Union{eltype(encoding), eltype(x)}, ET}(a => (b - 1) for (b, a) in enumerate(encoding))
145            deltas = eltype(x)[]
146            inds = Vector{ET}(undef, len)
147            categorical = typeof(x).name.name == :CategoricalArray
148            for (j, val) in enumerate(x)
149                if categorical
150                    val = get(val)
151                end
152                @inbounds inds[j] = get!(pool, val) do
153                    push!(deltas, val)
154                    length(pool)
155                end
156            end
157            if !isempty(deltas)
158                data = arrowvector(deltas, i, nl, fi, de, ded, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
159                push!(ded, DictEncoding{eltype(data), typeof(data)}(id, data, false, getmetadata(data)))
160                if typeof(encoding.data) <: ChainedVector
161                    append!(encoding.data, data)
162                else
163                    data2 = ChainedVector([encoding.data, data])
164                    encoding = DictEncoding{eltype(data2), typeof(data2)}(id, data2, false, getmetadata(encoding))
165                    de[id] = Lockable(encoding)
166                end
167            end
168        end
169    end
170    if meta !== nothing && getmetadata(encoding) !== nothing
171        merge!(meta, getmetadata(encoding))
172    elseif getmetadata(encoding) !== nothing
173        meta = getmetadata(encoding)
174    end
175    return DictEncoded(UInt8[], validity, inds, encoding, meta)
176end
177
178@propagate_inbounds function Base.getindex(d::DictEncoded, i::Integer)
179    @boundscheck checkbounds(d, i)
180    @inbounds valid = d.validity[i]
181    !valid && return missing
182    @inbounds idx = d.indices[i]
183    return @inbounds d.encoding[idx + 1]
184end
185
186@propagate_inbounds function Base.setindex!(d::DictEncoded{T}, v, i::Integer) where {T}
187    @boundscheck checkbounds(d, i)
188    if v === missing
189        @inbounds d.validity[i] = false
190    else
191        ix = findfirst(d.encoding.data, v)
192        if ix === nothing
193            push!(d.encoding.data, v)
194            @inbounds d.indices[i] = length(d.encoding.data) - 1
195        else
196            @inbounds d.indices[i] = ix - 1
197        end
198    end
199    return v
200end
201
202function Base.copy(x::DictEncoded{T, S}) where {T, S}
203    pool = copy(x.encoding.data)
204    valid = x.validity
205    inds = x.indices
206    refs = copy(inds)
207    @inbounds for i = 1:length(inds)
208        refs[i] = refs[i] + one(S)
209    end
210    return PooledArray(PooledArrays.RefArray(refs), Dict{T, S}(val => i for (i, val) in enumerate(pool)), pool)
211end
212
213function compress(Z::Meta.CompressionType, comp, x::A) where {A <: DictEncoded}
214    len = length(x)
215    nc = nullcount(x)
216    validity = compress(Z, comp, x.validity)
217    inds = compress(Z, comp, x.indices)
218    return Compressed{Z, A}(x, [validity, inds], len, nc, Compressed[])
219end
220
221function makenodesbuffers!(col::DictEncoded{T, S}, fieldnodes, fieldbuffers, bufferoffset, alignment) where {T, S}
222    len = length(col)
223    nc = nullcount(col)
224    push!(fieldnodes, FieldNode(len, nc))
225    @debug 1 "made field node: nodeidx = $(length(fieldnodes)), col = $(typeof(col)), len = $(fieldnodes[end].length), nc = $(fieldnodes[end].null_count)"
226    # validity bitmap
227    blen = nc == 0 ? 0 : bitpackedbytes(len, alignment)
228    push!(fieldbuffers, Buffer(bufferoffset, blen))
229    @debug 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
230    bufferoffset += blen
231    # indices
232    blen = sizeof(S) * len
233    push!(fieldbuffers, Buffer(bufferoffset, blen))
234    @debug 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
235    bufferoffset += padding(blen, alignment)
236    return bufferoffset
237end
238
239function writebuffer(io, col::DictEncoded, alignment)
240    @debug 1 "writebuffer: col = $(typeof(col))"
241    @debug 2 col
242    writebitmap(io, col, alignment)
243    # write indices
244    n = writearray(io, col.indices)
245    @debug 1 "writing array: col = $(typeof(col.indices)), n = $n, padded = $(padding(n, alignment))"
246    writezeros(io, paddinglength(n, alignment))
247    return
248end
249