/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

import scala.collection.mutable

import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import org.apache.avro.{Schema, SchemaNormalization}
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.io._
import org.apache.commons.io.IOUtils

import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.Utils

/**
 * Custom serializer used for generic Avro records. If the user registers the schemas ahead of
 * time, then the schema's fingerprint is sent with each message instead of the actual schema,
 * which reduces network IO.
 * Actions like parsing or compressing schemas are computationally expensive, so the serializer
 * caches all previously seen values to reduce the amount of work needed per record.
 * A brief usage sketch appears at the end of this file.
 * @param schemas a map from a schema's fingerprint (a unique ID) to the string representation of
 *                that Avro schema, used to decrease the amount of data that needs to be
 *                serialized.
 */
private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
  extends KSerializer[GenericRecord] {

  /** Caches compressed and decompressed schemas to reduce repeated (de)compression work */
  private val compressCache = new mutable.HashMap[Schema, Array[Byte]]()
  private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]()

  /** Reuses the same datum reader/writer since the same schema will be used many times */
  private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]()
  private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]()

  /** Fingerprinting is very expensive, so these caches map schemas to fingerprints and back */
  private val fingerprintCache = new mutable.HashMap[Schema, Long]()
  private val schemaCache = new mutable.HashMap[Long, Schema]()

  // GenericAvroSerializer can't take a SparkConf in the constructor because then it would become
  // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make the
  // codec lazy here because some unit tests use a KryoSerializer without having the SparkEnv set
  // (note those tests would fail if they tried to serialize Avro data).
  private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf)

  /**
   * Used to compress schemas when they are being sent over the wire.
   * The compression results are memoized to reduce the compression time, since the
   * same schema is compressed many times over.
   */
  def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
    val bos = new ByteArrayOutputStream()
    val out = codec.compressedOutputStream(bos)
    Utils.tryWithSafeFinally {
      out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
    } {
      out.close()
    }
    bos.toByteArray
  })

  /**
   * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already
   * seen values to limit the number of times that decompression has to be done.
   */
  def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
    val bis = new ByteArrayInputStream(
      schemaBytes.array(),
      schemaBytes.arrayOffset() + schemaBytes.position(),
      schemaBytes.remaining())
    val in = codec.compressedInputStream(bis)
    val bytes = Utils.tryWithSafeFinally {
      IOUtils.toByteArray(in)
    } {
      in.close()
    }
    new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
  })

  /**
   * Serializes a record to the given output stream. It caches much of the internal data
   * (fingerprints, compressed schemas, and datum writers) so that work is not redone.
   */
  def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = {
    val encoder = EncoderFactory.get.binaryEncoder(output, null)
    val schema = datum.getSchema
    val fingerprint = fingerprintCache.getOrElseUpdate(schema, {
      SchemaNormalization.parsingFingerprint64(schema)
    })
    schemas.get(fingerprint) match {
      case Some(_) =>
        output.writeBoolean(true)
        output.writeLong(fingerprint)
      case None =>
        output.writeBoolean(false)
        val compressedSchema = compress(schema)
        output.writeInt(compressedSchema.length)
        output.writeBytes(compressedSchema)
    }

    writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema))
      .asInstanceOf[DatumWriter[R]]
      .write(datum, encoder)
    encoder.flush()
  }

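  // For reference, the byte layout produced by serializeDatum above (and consumed by
  // deserializeDatum below) is:
  //   - a boolean flag: true if the schema's fingerprint was pre-registered, false otherwise
  //   - if true: the schema's 64-bit fingerprint
  //   - if false: an int length followed by that many bytes of the compressed schema
  //   - the record itself, written with Avro's binary encoding
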
  /**
   * Deserializes generic records into their in-memory form. There is internal
   * state to keep a cache of already seen schemas and datum readers.
   */
  def deserializeDatum(input: KryoInput): GenericRecord = {
    val schema = {
      if (input.readBoolean()) {
        val fingerprint = input.readLong()
        schemaCache.getOrElseUpdate(fingerprint, {
          schemas.get(fingerprint) match {
            case Some(s) => new Schema.Parser().parse(s)
            case None =>
              throw new SparkException(
                "Error attempting to read Avro data -- encountered an unknown " +
                  s"fingerprint: $fingerprint, not sure what schema to use. This could happen " +
                  "if you registered additional schemas after starting your SparkContext.")
          }
        })
      } else {
        val length = input.readInt()
        decompress(ByteBuffer.wrap(input.readBytes(length)))
      }
    }
    val decoder = DecoderFactory.get.directBinaryDecoder(input, null)
    readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema))
      .asInstanceOf[DatumReader[GenericRecord]]
      .read(null, decoder)
  }

  override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit =
    serializeDatum(datum, output)

  override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord =
    deserializeDatum(input)
}
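
// Usage sketch (illustrative only, not part of this file): to benefit from the fingerprint-based
// wire format above, schemas are registered on the SparkConf before the serializer is created,
// e.g. via SparkConf.registerAvroSchemas. The schema below is a made-up example built with
// Avro's SchemaBuilder.
//
//   import org.apache.avro.SchemaBuilder
//   import org.apache.spark.SparkConf
//
//   val schema = SchemaBuilder
//     .record("TestRecord").fields()
//     .requiredString("data")
//     .endRecord()
//
//   val conf = new SparkConf()
//     .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
//     .registerAvroSchemas(schema)
//
// Records whose schema was registered this way are written with just the schema's 64-bit
// fingerprint; records with unregistered schemas fall back to shipping the compressed schema
// alongside every record.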