module Groonga
  module ExpressionRewriters
    # Expression rewriter registered under the name "optimizer".
    #
    # It builds an expression tree from the target expression, applies
    # the rewrites below and builds a new expression from the result:
    #
    # * "constant OP variable" comparisons are normalized to
    #   "variable OP' constant" where OP' is the mirrored operator.
    # * Sub nodes of an AND operation are grouped by column, a pair of
    #   lower/upper bound comparisons on the same column is replaced by
    #   a between() call, and the resulting sub nodes are sorted by
    #   their estimated size.
    class Optimizer < ExpressionRewriter
      register "optimizer"

      # Builds the optimized expression.
      #
      # The table is resolved from the domain of the first element of
      # the source expression (NOTE(review): presumably the record
      # variable of the expression -- confirm).
      #
      # Returns a newly created expression built from the optimized
      # tree; the source expression itself is not modified here.
      def rewrite
        builder = ExpressionTreeBuilder.new(@expression)
        root_node = builder.build

        variable = @expression[0]
        table = context[variable.domain]
        optimized_root_node = optimize_node(table, root_node)

        rewritten = Expression.create(table)
        optimized_root_node.build(rewritten)
        rewritten
      end

      private
      # Recursively optimizes +node+ and returns the optimized node.
      #
      # Logical operations always get a new node holding the optimized
      # sub nodes. Binary operations are returned as is when nothing
      # changed. Any other node type is returned untouched.
      def optimize_node(table, node)
        case node
        when ExpressionTree::LogicalOperation
          optimized_sub_nodes = node.nodes.collect do |sub_node|
            optimize_node(table, sub_node)
          end
          case node.operator
          when Operator::AND
            optimized_sub_nodes =
              optimize_and_sub_nodes(table, optimized_sub_nodes)
          end
          ExpressionTree::LogicalOperation.new(node.operator,
                                               optimized_sub_nodes)
        when ExpressionTree::BinaryOperation
          optimized_left = optimize_node(table, node.left)
          optimized_right = optimize_node(table, node.right)

          # "constant OP variable" may only be flipped to
          # "variable OP' constant" when OP has a mirrored form.
          # Reusing OP as is would turn "1 < x" into "x < 1".
          swapped_operator = nil
          if optimized_left.is_a?(ExpressionTree::Constant) &&
              optimized_right.is_a?(ExpressionTree::Variable)
            swapped_operator = mirrored_operator(node.operator)
          end

          if swapped_operator
            ExpressionTree::BinaryOperation.new(swapped_operator,
                                                optimized_right,
                                                optimized_left)
          elsif node.left == optimized_left && node.right == optimized_right
            node
          else
            ExpressionTree::BinaryOperation.new(node.operator,
                                                optimized_left,
                                                optimized_right)
          end
        else
          node
        end
      end

      # Returns the operator that keeps the meaning of a comparison
      # when its operands are swapped, or nil when swapping the
      # operands of +operator+ is not known to be safe.
      def mirrored_operator(operator)
        case operator
        when Operator::EQUAL, Operator::NOT_EQUAL
          operator
        when Operator::LESS
          Operator::GREATER
        when Operator::GREATER
          Operator::LESS
        when Operator::LESS_EQUAL
          Operator::GREATER_EQUAL
        when Operator::GREATER_EQUAL
          Operator::LESS_EQUAL
        else
          nil
        end
      end

      # Optimizes the sub nodes of an AND operation.
      #
      # Binary operations whose left operand is a variable are grouped
      # by that variable's column and each group is optimized on its
      # own. All other sub nodes share the nil group and are kept as
      # is. The result is ordered by ascending estimate_size(table)
      # (NOTE(review): presumably so that the most selective condition
      # is evaluated first -- confirm).
      def optimize_and_sub_nodes(table, sub_nodes)
        grouped_sub_nodes = sub_nodes.group_by do |sub_node|
          case sub_node
          when ExpressionTree::BinaryOperation
            if sub_node.left.is_a?(ExpressionTree::Variable)
              sub_node.left.column
            else
              nil
            end
          else
            nil
          end
        end

        optimized_nodes = []
        grouped_sub_nodes.each do |column, grouped_nodes|
          if column
            grouped_nodes = optimize_grouped_nodes(column, grouped_nodes)
          end
          optimized_nodes.concat(grouped_nodes)
        end

        optimized_nodes.sort_by do |node|
          node.estimate_size(table)
        end
      end

      # Operators that make a node a candidate for the per column
      # optimization in optimize_grouped_nodes.
      COMPARISON_OPERATORS = [
        Operator::EQUAL,
        Operator::NOT_EQUAL,
        Operator::LESS,
        Operator::GREATER,
        Operator::LESS_EQUAL,
        Operator::GREATER_EQUAL,
      ]

      # Optimizes nodes that all refer to +column+ on their left side.
      #
      # Only comparisons against a constant are candidates. When there
      # are exactly two candidates and they form a lower/upper bound
      # pair, they are replaced by a single between() call. Otherwise
      # the candidates are kept unchanged. Non-candidate nodes come
      # first in the returned array.
      def optimize_grouped_nodes(column, grouped_nodes)
        target_nodes, done_nodes = grouped_nodes.partition do |node|
          node.is_a?(ExpressionTree::BinaryOperation) &&
            COMPARISON_OPERATORS.include?(node.operator) &&
            node.right.is_a?(ExpressionTree::Constant)
        end

        # TODO: target_nodes = remove_needless_nodes(target_nodes)
        # e.g.: x < 1 && x < 3 -> x < 1: (x < 3) is meaningless

        between_node = nil
        if target_nodes.size == 2
          between_node = try_optimize_between(column, target_nodes)
        end

        if between_node
          done_nodes << between_node
        else
          done_nodes.concat(target_nodes)
        end

        done_nodes
      end

      # Tries to express +target_nodes+ as one between() call.
      #
      # Returns nil unless the nodes contain both a lower bound
      # (> or >=) and an upper bound (< or <=). A strict comparison
      # becomes an "exclude" border and a non-strict one becomes an
      # "include" border.
      def try_optimize_between(column, target_nodes)
        greater_node = nil
        less_node = nil
        target_nodes.each do |node|
          case node.operator
          when Operator::GREATER, Operator::GREATER_EQUAL
            greater_node = node
          when Operator::LESS, Operator::LESS_EQUAL
            less_node = node
          end
        end
        return nil unless greater_node && less_node

        between = ExpressionTree::Procedure.new(context["between"])
        if greater_node.operator == Operator::GREATER
          greater_border = "exclude"
        else
          greater_border = "include"
        end
        if less_node.operator == Operator::LESS
          less_border = "exclude"
        else
          less_border = "include"
        end
        arguments = [
          ExpressionTree::Variable.new(column),
          greater_node.right,
          ExpressionTree::Constant.new(greater_border),
          less_node.right,
          ExpressionTree::Constant.new(less_border),
        ]
        ExpressionTree::FunctionCall.new(between, arguments)
      end
    end
  end
end
148