# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
module Arrow
  # Experimental
  #
  # Groups the rows of a table by the values of one or more key columns
  # and computes one aggregated value per group for the other columns.
  # Every aggregation returns a new table made of the key columns
  # followed by the aggregated columns.
  #
  # TODO: Most of this code should be implemented in Apache Arrow C++.
  class Group
    # table: the table to group. It must respond to +columns+, +n_rows+
    #   and +[]+ (column lookup by key).
    # keys: the names (anything responding to +to_s+) of the columns
    #   whose values identify a group.
    def initialize(table, keys)
      @table = table
      @keys = keys
    end

    # Counts the non-null values of every non-key column in each group.
    def count
      aggregate(value_columns) do |column, indexes|
        indexes.count { |index| !column.null?(index) }
      end
    end

    # Sums the non-null values of every numeric non-key column in each
    # group. A group without any non-null value sums to 0.
    def sum
      aggregate(numeric_columns) do |column, indexes|
        indexes.inject(0) do |total, index|
          value = column[index]
          value.nil? ? total : total + value
        end
      end
    end

    # Averages the non-null values of every numeric non-key column in
    # each group. A group without any non-null value averages to 0.0.
    def average
      aggregate(numeric_columns) do |column, indexes|
        mean = 0.0
        n = 0
        indexes.each do |index|
          value = column[index]
          next if value.nil?
          n += 1
          # Incremental mean: avoids accumulating one large total.
          mean += (value - mean) / n
        end
        mean
      end
    end

    # Finds the smallest non-null value of every numeric non-key column
    # in each group. A group without any non-null value yields nil.
    def min
      aggregate(numeric_columns) do |column, indexes|
        extreme_value(column, indexes) { |value, best| value < best }
      end
    end

    # Finds the largest non-null value of every numeric non-key column
    # in each group. A group without any non-null value yields nil.
    def max
      aggregate(numeric_columns) do |column, indexes|
        extreme_value(column, indexes) { |value, best| value > best }
      end
    end

    private

    # Key names as strings, the form they are compared with column names.
    def key_names
      @keys.collect(&:to_s)
    end

    # All columns of the table that are not used as a grouping key.
    def value_columns
      names = key_names
      @table.columns.reject { |column| names.include?(column.name) }
    end

    # The non-key columns whose data type is numeric.
    def numeric_columns
      value_columns.select do |column|
        column.data_type.is_a?(NumericDataType)
      end
    end

    # Scans the non-null values of +column+ at +indexes+ and keeps the
    # one the given block prefers. The block receives (value, best) and
    # returns true when +value+ must replace +best+. Returns nil when
    # every value is null.
    def extreme_value(column, indexes)
      best = nil
      indexes.each do |index|
        value = column[index]
        next if value.nil?
        best = value if best.nil? || yield(value, best)
      end
      best
    end

    # Builds the result table for one aggregation.
    #
    # Rows are sorted by their key values so that the rows of one group
    # are adjacent; the given block is then called once per group and
    # per target column with (column, row_indexes) and must return the
    # aggregated value. Groups appear in ascending key order, null key
    # values last. A table without rows produces a result without rows.
    def aggregate(target_columns)
      # Resolve the key columns once instead of once per row.
      key_columns = @keys.collect { |key| @table[key] }

      keyed_rows = @table.n_rows.times.collect do |i|
        [key_columns.collect { |key_column| key_column[i] }, i]
      end
      sorted = keyed_rows.sort_by do |key_values, i|
        # nil can't be compared with other values, so wrap every key
        # value: [0, value] for present values, [1] for nulls. This
        # keeps the order of present values and puts nulls last. The row
        # index breaks ties to keep the row order in a group stable.
        [key_values.collect { |value| value.nil? ? [1] : [0, value] }, i]
      end

      grouped_keys = []
      aggregated_arrays_raw = target_columns.collect { [] }
      groups = sorted.slice_when do |(left_keys, _), (right_keys, _)|
        left_keys != right_keys
      end
      groups.each do |rows|
        grouped_keys << rows.first.first
        indexes = rows.collect(&:last)
        target_columns.each_with_index do |column, j|
          aggregated_arrays_raw[j] << yield(column, indexes)
        end
      end

      fields = []
      arrays = []
      key_columns.each_with_index do |key_column, i|
        # Collect per key instead of transposing: transposing an empty
        # list of groups would lose the per-key arrays.
        key_values = grouped_keys.collect { |group_key| group_key[i] }
        fields << key_column.field
        arrays << key_column.data_type.build_array(key_values)
      end
      target_columns.each do |column|
        # NOTE(review): the data type ArrayBuilder infers for an empty
        # value list is not visible here - confirm it suits callers.
        array = ArrayBuilder.build(aggregated_arrays_raw[arrays.size - key_columns.size])
        arrays << array
        fields << Field.new(column.field.name, array.value_data_type)
      end
      Table.new(fields, arrays)
    end
  end
end
173