1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.expressions.codegen
19
20import java.io.ObjectInputStream
21
22import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
23import com.esotericsoftware.kryo.io.{Input, Output}
24
25import org.apache.spark.internal.Logging
26import org.apache.spark.sql.catalyst.InternalRow
27import org.apache.spark.sql.catalyst.expressions._
28import org.apache.spark.sql.types.StructType
29import org.apache.spark.util.Utils
30
/**
 * Concrete Scala base class for the orderings compiled by [[GenerateOrdering]].
 *
 * The generated Java class extends this class and overrides only `compare`; every other member
 * of `Ordering[InternalRow]` is inherited from the Scala implementation here, so the generated
 * code does not have to provide it.
 */
class BaseOrdering extends Ordering[InternalRow] {

  /**
   * Placeholder that the generated subclass overrides.
   *
   * @throws UnsupportedOperationException always, when called on this base class directly
   */
  def compare(a: InternalRow, b: InternalRow): Int = {
    throw new UnsupportedOperationException(
      "BaseOrdering.compare must be overridden by a generated ordering")
  }
}
39
/**
 * Generates bytecode for an [[Ordering]] of rows for a given set of expressions.
 */
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging {

  protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
    in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
    in.map(BindReferences.bindReference(_, inputSchema))

  /**
   * Builds one ascending [[SortOrder]] per field of `schema`, in field order. Each key is a
   * nullable [[BoundReference]] to the field's ordinal.
   */
  private def ascendingOrdering(schema: StructType): Seq[SortOrder] =
    schema.zipWithIndex.map { case (field, ordinal) =>
      SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
    }

  /**
   * Creates a code gen ordering for sorting this schema, in ascending order.
   */
  def create(schema: StructType): BaseOrdering = create(ascendingOrdering(schema))

  /**
   * Generates the code for comparing a struct type according to its natural ordering
   * (i.e. ascending order by field 1, then field 2, ..., then field n).
   */
  def genComparisons(ctx: CodegenContext, schema: StructType): String =
    genComparisons(ctx, ascendingOrdering(schema))

  /**
   * Generates the Java statements that compare two rows by the given sort keys.
   *
   * The emitted code expects `InternalRow a` and `InternalRow b` to be in scope. It returns a
   * non-zero int as soon as one key decides the order, and falls through without returning when
   * every key ties, so the enclosing method must end with its own `return 0;`.
   *
   * `ctx.INPUT_ROW` and `ctx.currentVars` are overridden while the key expressions are generated
   * and restored to the caller's values before this method returns (or throws).
   *
   * @param ctx      code generation context owned by the caller
   * @param ordering sort keys, compared in sequence
   * @return Java statements performing the comparison
   */
  def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = {
    // Java local that points at whichever row (`a` or `b`) is currently being evaluated.
    val inputRow = "i"
    // This method is public and `ctx` belongs to the caller, so remember its row-access state
    // and put it back afterwards instead of leaking our overrides into later code generation.
    val oldInputRow = ctx.INPUT_ROW
    val oldCurrentVars = ctx.currentVars
    val code = try {
      val comparisons = ordering.map(genComparison(ctx, inputRow, _))
      ctx.splitExpressions(
        expressions = comparisons,
        funcName = "compare",
        arguments = Seq(("InternalRow", "a"), ("InternalRow", "b")),
        returnType = "int",
        makeSplitFunction = { body =>
          s"""
            InternalRow $inputRow = null;  // Holds current row being evaluated.
            $body
            return 0;
          """
        },
        foldFunctions = { funCalls =>
          // Each split function returns the first non-zero comparison of its keys, or 0.
          funCalls.map { funCall =>
            val comp = ctx.freshName("comp")
            s"""
              int $comp = $funCall;
              if ($comp != 0) {
                return $comp;
              }
            """
          }.mkString
        })
    } finally {
      ctx.INPUT_ROW = oldInputRow
      ctx.currentVars = oldCurrentVars
    }
    // Make sure the row variable is declared even if splitExpressions returns an inlined block.
    // Plain concatenation (not stripMargin) so generated lines that start with `|` stay intact.
    s"\nInternalRow $inputRow = null;\n$code\n"
  }

  /**
   * Generates the Java statements that compare rows `a` and `b` on a single sort key.
   *
   * The key expression is evaluated once against each row through `inputRow`. The statements
   * `return` as soon as this key decides the order; they fall through when both values are null
   * or when the two non-null values compare equal.
   */
  private def genComparison(ctx: CodegenContext, inputRow: String, order: SortOrder): String = {
    ctx.INPUT_ROW = inputRow
    // to use INPUT_ROW we must make sure currentVars is null
    ctx.currentVars = null
    val eval = order.child.genCode(ctx)
    val javaType = ctx.javaType(order.child.dataType)
    val isNullA = ctx.freshName("isNullA")
    val primitiveA = ctx.freshName("primitiveA")
    val isNullB = ctx.freshName("isNullB")
    val primitiveB = ctx.freshName("primitiveB")
    // Results when exactly one side is null: (only `a` is null, only `b` is null).
    val (onlyANull, onlyBNull) = order.nullOrdering match {
      case NullsFirst => ("-1", "1")
      case NullsLast => ("1", "-1")
    }
    // Non-null values use the type's comparison, negated for descending keys.
    val result = if (order.isAscending) "comp" else "-comp"
    s"""
        $inputRow = a;
        boolean $isNullA;
        $javaType $primitiveA;
        {
          ${eval.code}
          $isNullA = ${eval.isNull};
          $primitiveA = ${eval.value};
        }
        $inputRow = b;
        boolean $isNullB;
        $javaType $primitiveB;
        {
          ${eval.code}
          $isNullB = ${eval.isNull};
          $primitiveB = ${eval.value};
        }
        if ($isNullA && $isNullB) {
          // Nothing
        } else if ($isNullA) {
          return $onlyANull;
        } else if ($isNullB) {
          return $onlyBNull;
        } else {
          int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
          if (comp != 0) {
            return $result;
          }
        }
    """
  }

  /**
   * Compiles a [[BaseOrdering]] whose `compare` checks each key of `ordering` in turn. It returns
   * the first non-zero comparison, or 0 when every key ties.
   */
  protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
    val ctx = newCodeGenContext()
    val comparisons = genComparisons(ctx, ordering)
    val codeBody = s"""
      public SpecificOrdering generate(Object[] references) {
        return new SpecificOrdering(references);
      }

      class SpecificOrdering extends ${classOf[BaseOrdering].getName} {

        private Object[] references;
        ${ctx.declareMutableStates()}

        public SpecificOrdering(Object[] references) {
          this.references = references;
          ${ctx.initMutableStates()}
        }

        ${ctx.declareAddedFunctions()}

        public int compare(InternalRow a, InternalRow b) {
          $comparisons
          return 0;
        }
      }"""

    val code = CodeFormatter.stripOverlappingComments(
      new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
    logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}")

    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
  }
}
191
/**
 * An [[Ordering]] over [[InternalRow]]s backed by a comparator compiled by [[GenerateOrdering]].
 *
 * Only the sort keys travel through serialization; the compiled comparator lives in a transient
 * field and is rebuilt from the keys after Java or Kryo deserialization. Despite the name, the
 * comparator is compiled eagerly whenever an instance is constructed.
 *
 * @param ordering the sort keys the comparator is generated from
 */
class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
  extends Ordering[InternalRow] with KryoSerializable {

  /** Binds every key of `ordering` against `inputSchema` before generating the comparator. */
  def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
    this(ordering.map(order => BindReferences.bindReference(order, inputSchema)))

  // Compiled comparator; never serialized, always regenerated from the sort keys.
  @transient
  private[this] var generatedOrdering = GenerateOrdering.generate(ordering)

  def compare(a: InternalRow, b: InternalRow): Int = generatedOrdering.compare(a, b)

  // Java serialization: restore the fields, then recompile the transient comparator.
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    generatedOrdering = GenerateOrdering.generate(ordering)
  }

  // Kryo serialization: only the sort keys are written, as an array.
  override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
    val sortOrders: Array[SortOrder] = ordering.toArray
    kryo.writeObject(out, sortOrders)
  }

  override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
    val sortOrders = kryo.readObject(in, classOf[Array[SortOrder]])
    // NOTE(review): only the comparator is rebuilt here; the `ordering` val cannot be assigned.
    // If Kryo instantiates this class without running its constructor, `ordering` stays null
    // afterwards and a later `write` would fail on it — confirm against the Kryo instantiator
    // strategy in use.
    generatedOrdering = GenerateOrdering.generate(sortOrders)
  }
}
221
object LazilyGeneratedOrdering {

  /**
   * Builds a [[LazilyGeneratedOrdering]] that compares rows of `schema` field by field in
   * ascending order (field 0 first). Every field is referenced by ordinal and treated as nullable.
   */
  def forSchema(schema: StructType): LazilyGeneratedOrdering = {
    val ascendingKeys = schema.zipWithIndex.map { case (field, ordinal) =>
      val column = BoundReference(ordinal, field.dataType, nullable = true)
      SortOrder(column, Ascending)
    }
    new LazilyGeneratedOrdering(ascendingKeys)
  }
}
234