/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions.codegen

import java.io.ObjectInputStream

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
 * Concrete base class for code-generated row orderings.
 *
 * Generated Java classes extend this class rather than the `Ordering[InternalRow]` trait, so that
 * they inherit the trait's default method implementations. Subclasses must override `compare`;
 * the implementation here exists only to satisfy the trait and always throws
 * `UnsupportedOperationException`.
 */
class BaseOrdering extends Ordering[InternalRow] {
  def compare(a: InternalRow, b: InternalRow): Int = throw new UnsupportedOperationException
}

/**
 * Generates bytecode for an [[Ordering]] of rows for a given set of expressions.
* The generated Java class extends [[BaseOrdering]]; the comparison logic for each sort key is
 * produced by `genComparisons`.
 */
object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging {

  protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
    in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

  protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
    in.map(BindReferences.bindReference(_, inputSchema))

  /**
   * Creates a code gen ordering for sorting this schema, in ascending order.
   */
  def create(schema: StructType): BaseOrdering = {
    create(schema.zipWithIndex.map { case (field, ordinal) =>
      SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
    })
  }

  /**
   * Generates the code for comparing a struct type according to its natural ordering
   * (i.e. ascending order by field 1, then field 2, ..., then field n).
   */
  def genComparisons(ctx: CodegenContext, schema: StructType): String = {
    val ordering = schema.fields.map(_.dataType).zipWithIndex.map {
      case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending)
    }
    genComparisons(ctx, ordering)
  }

  /**
   * Generates Java code that compares two rows, `a` and `b`, key by key in the given order.
   *
   * The emitted code expects `InternalRow a` and `InternalRow b` to be in scope and must be
   * placed inside a method returning `int`. It returns from that method as soon as a key
   * differs, and falls through when all keys compare equal. Long comparisons may be split into
   * helper methods by `ctx.splitExpressions`. The code declares its own row variable.
   *
   * `ctx.INPUT_ROW` and `ctx.currentVars` are temporarily overridden while the sort-key
   * expressions generate code, and are restored before this method returns.
   *
   * @param ctx      codegen context supplied by the caller
   * @param ordering sort keys, compared in sequence
   * @return the Java source of the comparison
   */
  def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = {
    // Java variable holding the row currently being evaluated: the sort-key expressions are
    // generated against it (via INPUT_ROW) and it is pointed at `a`, then at `b`.
    val inputRow = "i"
    // The context belongs to the caller, so restore INPUT_ROW on exit, just as currentVars is
    // restored after each key below.
    val oldInputRow = ctx.INPUT_ROW

    val comparisons = ordering.map { order =>
      val oldCurrentVars = ctx.currentVars
      ctx.INPUT_ROW = inputRow
      // to use INPUT_ROW we must make sure currentVars is null
      ctx.currentVars = null
      val eval = order.child.genCode(ctx)
      ctx.currentVars = oldCurrentVars
      val asc = order.isAscending
      val isNullA = ctx.freshName("isNullA")
      val primitiveA = ctx.freshName("primitiveA")
      val isNullB = ctx.freshName("isNullB")
      val primitiveB = ctx.freshName("primitiveB")
      val javaType = ctx.javaType(order.child.dataType)
      // Results returned when only `a` (resp. only `b`) is null.
      val (onlyANull, onlyBNull) = order.nullOrdering match {
        case NullsFirst => ("-1", "1")
        case NullsLast => ("1", "-1")
      }
      // eval.code is emitted once per row; the braces keep any locals it declares scoped to a
      // single evaluation.
      s"""
          $inputRow = a;
          boolean $isNullA;
          $javaType $primitiveA;
          {
            ${eval.code}
            $isNullA = ${eval.isNull};
            $primitiveA = ${eval.value};
          }
          $inputRow = b;
          boolean $isNullB;
          $javaType $primitiveB;
          {
            ${eval.code}
            $isNullB = ${eval.isNull};
            $primitiveB = ${eval.value};
          }
          if ($isNullA && $isNullB) {
            // Nothing
          } else if ($isNullA) {
            return $onlyANull;
          } else if ($isNullB) {
            return $onlyBNull;
          } else {
            int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
            if (comp != 0) {
              return ${if (asc) "comp" else "-comp"};
            }
          }
      """
    }

    val code = ctx.splitExpressions(
      expressions = comparisons,
      funcName = "compare",
      arguments = Seq(("InternalRow", "a"), ("InternalRow", "b")),
      returnType = "int",
      makeSplitFunction = { body =>
        s"""
          InternalRow $inputRow = null;  // Holds current row being evaluated.
          $body
          return 0;
        """
      },
      foldFunctions = { funCalls =>
        // Call the split functions in order and return the first non-zero result.
        funCalls.map { funCall =>
          val comp = ctx.freshName("comp")
          s"""
            int $comp = $funCall;
            if ($comp != 0) {
              return $comp;
            }
          """
        }.mkString
      })

    ctx.INPUT_ROW = oldInputRow

    // Make sure the row variable is declared even if splitExpressions returns an inlined block.
    // Plain concatenation is used on purpose: stripMargin would also rewrite any line of the
    // generated code that begins with whitespace followed by '|'.
    s"InternalRow $inputRow = null;\n$code\n"
  }

  /**
   * Compiles a [[BaseOrdering]] subclass whose `compare` method applies `ordering`.
   * The generated source is logged at debug level before compilation.
   */
  protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
    val ctx = newCodeGenContext()
    val comparisons = genComparisons(ctx, ordering)
    val codeBody = s"""
      public SpecificOrdering generate(Object[] references) {
        return new SpecificOrdering(references);
      }

      class SpecificOrdering extends ${classOf[BaseOrdering].getName} {

        private Object[] references;
        ${ctx.declareMutableStates()}

        public SpecificOrdering(Object[] references) {
          this.references = references;
          ${ctx.initMutableStates()}
        }

        ${ctx.declareAddedFunctions()}

        public int compare(InternalRow a, InternalRow b) {
          $comparisons
          return 0;
        }
      }"""

    val code = CodeFormatter.stripOverlappingComments(
      new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
    logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}")

    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
  }
}

/**
 * A lazily generated row ordering comparator.
* The compiled comparator lives in a `@transient` field and is never serialized. It is generated
 * from `ordering` when an instance is constructed, and generated again after Java
 * deserialization (`readObject`) or Kryo deserialization (`read`).
 *
 * NOTE(review): the Kryo `read` path only rebuilds the compiled comparator; it does not (and,
 * since `ordering` is a `val`, cannot) assign `ordering`. After Kryo deserialization that field
 * holds whatever Kryo's instantiation left there (presumably null). Confirm before reading
 * `ordering` or re-serializing such an instance.
 */
class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
  extends Ordering[InternalRow] with KryoSerializable {

  /** Binds every sort key to `inputSchema`, then delegates to the primary constructor. */
  def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
    this(ordering.map(order => BindReferences.bindReference(order, inputSchema)))

  @transient
  private[this] var generatedOrdering = GenerateOrdering.generate(ordering)

  def compare(a: InternalRow, b: InternalRow): Int = generatedOrdering.compare(a, b)

  // Java serialization hook: restore the non-transient state, then recompile the comparator.
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    in.defaultReadObject()
    generatedOrdering = GenerateOrdering.generate(ordering)
  }

  // Only the sort keys are written; the compiled comparator is rebuilt on the reading side.
  override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
    val keys: Array[SortOrder] = ordering.toArray
    kryo.writeObject(out, keys)
  }

  override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
    val keys = kryo.readObject(in, classOf[Array[SortOrder]])
    generatedOrdering = GenerateOrdering.generate(keys)
  }
}

object LazilyGeneratedOrdering {

  /**
   * Creates a [[LazilyGeneratedOrdering]] that compares rows of `schema` column by column, in
   * column order, with every column sorted `Ascending`.
   */
  def forSchema(schema: StructType): LazilyGeneratedOrdering = {
    val ascendingKeys = schema.zipWithIndex.map { case (field, ordinal) =>
      SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
    }
    new LazilyGeneratedOrdering(ascendingKeys)
  }
}