1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.ml.recommendation
19
20import java.{util => ju}
21import java.io.IOException
22
23import scala.collection.mutable
24import scala.reflect.ClassTag
25import scala.util.{Sorting, Try}
26import scala.util.hashing.byteswap64
27
28import com.github.fommil.netlib.BLAS.{getInstance => blas}
29import org.apache.hadoop.fs.Path
30import org.json4s.DefaultFormats
31import org.json4s.JsonDSL._
32
33import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext}
34import org.apache.spark.annotation.{DeveloperApi, Since}
35import org.apache.spark.internal.Logging
36import org.apache.spark.ml.{Estimator, Model}
37import org.apache.spark.ml.param._
38import org.apache.spark.ml.param.shared._
39import org.apache.spark.ml.util._
40import org.apache.spark.mllib.linalg.CholeskyDecomposition
41import org.apache.spark.mllib.optimization.NNLS
42import org.apache.spark.rdd.RDD
43import org.apache.spark.sql.{DataFrame, Dataset}
44import org.apache.spark.sql.functions._
45import org.apache.spark.sql.types._
46import org.apache.spark.storage.StorageLevel
47import org.apache.spark.util.Utils
48import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
49import org.apache.spark.util.random.XORShiftRandom
50
51/**
52 * Common params for ALS and ALSModel.
53 */
54private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
55  /**
56   * Param for the column name for user ids. Ids must be integers. Other
57   * numeric types are supported for this column, but will be cast to integers as long as they
58   * fall within the integer value range.
59   * Default: "user"
60   * @group param
61   */
62  val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
63    "the integer value range.")
64
65  /** @group getParam */
66  def getUserCol: String = $(userCol)
67
68  /**
69   * Param for the column name for item ids. Ids must be integers. Other
70   * numeric types are supported for this column, but will be cast to integers as long as they
71   * fall within the integer value range.
72   * Default: "item"
73   * @group param
74   */
75  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
76    "the integer value range.")
77
78  /** @group getParam */
79  def getItemCol: String = $(itemCol)
80
81  /**
82   * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
83   * out of integer range.
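   *
   * A minimal sketch of how this is used in practice (assuming a DataFrame `df` with a numeric
   * user id column; the DataFrame itself is not defined in this file):
   * {{{
   *   df.select(checkedCast(df($(userCol)).cast(DoubleType)))
   * }}}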
84   */
85  protected val checkedCast = udf { (n: Double) =>
86    if (n > Int.MaxValue || n < Int.MinValue) {
87      throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
88        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
89    } else {
90      n.toInt
91    }
92  }
93}
94
95/**
96 * Common params for ALS.
97 */
98private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
99  with HasPredictionCol with HasCheckpointInterval with HasSeed {
100
101  /**
102   * Param for rank of the matrix factorization (positive).
103   * Default: 10
104   * @group param
105   */
106  val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))
107
108  /** @group getParam */
109  def getRank: Int = $(rank)
110
111  /**
112   * Param for number of user blocks (positive).
113   * Default: 10
114   * @group param
115   */
116  val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
117    ParamValidators.gtEq(1))
118
119  /** @group getParam */
120  def getNumUserBlocks: Int = $(numUserBlocks)
121
122  /**
123   * Param for number of item blocks (positive).
124   * Default: 10
125   * @group param
126   */
127  val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
128      ParamValidators.gtEq(1))
129
130  /** @group getParam */
131  def getNumItemBlocks: Int = $(numItemBlocks)
132
133  /**
134   * Param to decide whether to use implicit preference.
135   * Default: false
136   * @group param
137   */
138  val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference")
139
140  /** @group getParam */
141  def getImplicitPrefs: Boolean = $(implicitPrefs)
142
143  /**
144   * Param for the alpha parameter in the implicit preference formulation (nonnegative).
145   * Default: 1.0
146   * @group param
147   */
148  val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference",
149    ParamValidators.gtEq(0))
150
151  /** @group getParam */
152  def getAlpha: Double = $(alpha)
153
154  /**
155   * Param for the column name for ratings.
156   * Default: "rating"
157   * @group param
158   */
159  val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")
160
161  /** @group getParam */
162  def getRatingCol: String = $(ratingCol)
163
164  /**
165   * Param for whether to apply nonnegativity constraints.
166   * Default: false
167   * @group param
168   */
169  val nonnegative = new BooleanParam(
170    this, "nonnegative", "whether to use nonnegative constraint for least squares")
171
172  /** @group getParam */
173  def getNonnegative: Boolean = $(nonnegative)
174
175  /**
176   * Param for StorageLevel for intermediate datasets. Pass in a string representation of
177   * `StorageLevel`. Cannot be "NONE".
178   * Default: "MEMORY_AND_DISK".
179   *
180   * @group expertParam
181   */
182  val intermediateStorageLevel = new Param[String](this, "intermediateStorageLevel",
183    "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
184    (s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE")
185
186  /** @group expertGetParam */
187  def getIntermediateStorageLevel: String = $(intermediateStorageLevel)
188
189  /**
190   * Param for StorageLevel for ALS model factors. Pass in a string representation of
191   * `StorageLevel`.
192   * Default: "MEMORY_AND_DISK".
193   *
194   * @group expertParam
195   */
196  val finalStorageLevel = new Param[String](this, "finalStorageLevel",
197    "StorageLevel for ALS model factors.",
198    (s: String) => Try(StorageLevel.fromString(s)).isSuccess)
199
200  /** @group expertGetParam */
201  def getFinalStorageLevel: String = $(finalStorageLevel)
202
203  setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
204    implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
205    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
206    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
207
208  /**
209   * Validates and transforms the input schema.
210   *
211   * @param schema input schema
212   * @return output schema
213   */
214  protected def validateAndTransformSchema(schema: StructType): StructType = {
215    // user and item will be cast to Int
216    SchemaUtils.checkNumericType(schema, $(userCol))
217    SchemaUtils.checkNumericType(schema, $(itemCol))
218    // rating will be cast to Float
219    SchemaUtils.checkNumericType(schema, $(ratingCol))
220    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
221  }
222}
223
224/**
225 * Model fitted by ALS.
226 *
227 * @param rank rank of the matrix factorization model
228 * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
229 * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
230 */
231@Since("1.3.0")
232class ALSModel private[ml] (
233    @Since("1.4.0") override val uid: String,
234    @Since("1.4.0") val rank: Int,
235    @transient val userFactors: DataFrame,
236    @transient val itemFactors: DataFrame)
237  extends Model[ALSModel] with ALSModelParams with MLWritable {
238
239  /** @group setParam */
240  @Since("1.4.0")
241  def setUserCol(value: String): this.type = set(userCol, value)
242
243  /** @group setParam */
244  @Since("1.4.0")
245  def setItemCol(value: String): this.type = set(itemCol, value)
246
247  /** @group setParam */
248  @Since("1.3.0")
249  def setPredictionCol(value: String): this.type = set(predictionCol, value)
250
251  @Since("2.0.0")
252  override def transform(dataset: Dataset[_]): DataFrame = {
253    transformSchema(dataset.schema)
    // Register a UDF for DataFrame, and then
    // create a new column named $(predictionCol) by running the predict UDF.
256    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
257      if (userFeatures != null && itemFeatures != null) {
258        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
259      } else {
260        Float.NaN
261      }
262    }
263    dataset
264      .join(userFactors,
265        checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
266      .join(itemFactors,
267        checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
268      .select(dataset("*"),
269        predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
270  }
271
272  @Since("1.3.0")
273  override def transformSchema(schema: StructType): StructType = {
274    // user and item will be cast to Int
275    SchemaUtils.checkNumericType(schema, $(userCol))
276    SchemaUtils.checkNumericType(schema, $(itemCol))
277    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
278  }
279
280  @Since("1.5.0")
281  override def copy(extra: ParamMap): ALSModel = {
282    val copied = new ALSModel(uid, rank, userFactors, itemFactors)
283    copyValues(copied, extra).setParent(parent)
284  }
285
286  @Since("1.6.0")
287  override def write: MLWriter = new ALSModel.ALSModelWriter(this)
288}
289
290@Since("1.6.0")
291object ALSModel extends MLReadable[ALSModel] {
292
293  @Since("1.6.0")
294  override def read: MLReader[ALSModel] = new ALSModelReader
295
296  @Since("1.6.0")
297  override def load(path: String): ALSModel = super.load(path)
298
299  private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
300
301    override protected def saveImpl(path: String): Unit = {
302      val extraMetadata = "rank" -> instance.rank
303      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
304      val userPath = new Path(path, "userFactors").toString
305      instance.userFactors.write.format("parquet").save(userPath)
306      val itemPath = new Path(path, "itemFactors").toString
307      instance.itemFactors.write.format("parquet").save(itemPath)
308    }
309  }
310
311  private class ALSModelReader extends MLReader[ALSModel] {
312
313    /** Checked against metadata when loading model */
314    private val className = classOf[ALSModel].getName
315
316    override def load(path: String): ALSModel = {
317      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
318      implicit val format = DefaultFormats
319      val rank = (metadata.metadata \ "rank").extract[Int]
320      val userPath = new Path(path, "userFactors").toString
321      val userFactors = sparkSession.read.format("parquet").load(userPath)
322      val itemPath = new Path(path, "itemFactors").toString
323      val itemFactors = sparkSession.read.format("parquet").load(itemPath)
324
325      val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
326
327      DefaultParamsReader.getAndSetParams(model, metadata)
328      model
329    }
330  }
331}
332
333/**
334 * Alternating Least Squares (ALS) matrix factorization.
335 *
336 * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
337 * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
338 * The general approach is iterative. During each iteration, one of the factor matrices is held
339 * constant, while the other is solved for using least squares. The newly-solved factor matrix is
340 * then held constant while solving for the other factor matrix.
341 *
342 * This is a blocked implementation of the ALS factorization algorithm that groups the two sets
343 * of factors (referred to as "users" and "products") into blocks and reduces communication by only
344 * sending one copy of each user vector to each product block on each iteration, and only for the
345 * product blocks that need that user's feature vector. This is achieved by pre-computing some
346 * information about the ratings matrix to determine the "out-links" of each user (which blocks of
347 * products it will contribute to) and "in-link" information for each product (which of the feature
348 * vectors it receives from each user block it will depend on). This allows us to send only an
349 * array of feature vectors between each user block and product block, and have the product block
350 * find the users' ratings and update the products based on these messages.
351 *
352 * For implicit preference data, the algorithm used is based on
353 * "Collaborative Filtering for Implicit Feedback Datasets", available at
354 * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here.
355 *
356 * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
357 * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
 * r is greater than 0 and 0 if r is less than or equal to 0. The ratings then act as 'confidence'
 * values related to the strength of indicated user preferences rather than explicit ratings given
 * to items.
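 *
 * A minimal usage sketch (the `ratings` DataFrame and its column names are illustrative
 * assumptions, not defined in this file):
 *
 * {{{
 *   // ratings: DataFrame with numeric columns "user", "item" and "rating"
 *   val als = new ALS()
 *     .setRank(10)
 *     .setMaxIter(10)
 *     .setRegParam(0.1)
 *     .setUserCol("user")
 *     .setItemCol("item")
 *     .setRatingCol("rating")
 *   val model = als.fit(ratings)
 *   val predictions = model.transform(ratings) // adds a "prediction" column
 * }}}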
361 */
362@Since("1.3.0")
363class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams
364  with DefaultParamsWritable {
365
366  import org.apache.spark.ml.recommendation.ALS.Rating
367
368  @Since("1.4.0")
369  def this() = this(Identifiable.randomUID("als"))
370
371  /** @group setParam */
372  @Since("1.3.0")
373  def setRank(value: Int): this.type = set(rank, value)
374
375  /** @group setParam */
376  @Since("1.3.0")
377  def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
378
379  /** @group setParam */
380  @Since("1.3.0")
381  def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
382
383  /** @group setParam */
384  @Since("1.3.0")
385  def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)
386
387  /** @group setParam */
388  @Since("1.3.0")
389  def setAlpha(value: Double): this.type = set(alpha, value)
390
391  /** @group setParam */
392  @Since("1.3.0")
393  def setUserCol(value: String): this.type = set(userCol, value)
394
395  /** @group setParam */
396  @Since("1.3.0")
397  def setItemCol(value: String): this.type = set(itemCol, value)
398
399  /** @group setParam */
400  @Since("1.3.0")
401  def setRatingCol(value: String): this.type = set(ratingCol, value)
402
403  /** @group setParam */
404  @Since("1.3.0")
405  def setPredictionCol(value: String): this.type = set(predictionCol, value)
406
407  /** @group setParam */
408  @Since("1.3.0")
409  def setMaxIter(value: Int): this.type = set(maxIter, value)
410
411  /** @group setParam */
412  @Since("1.3.0")
413  def setRegParam(value: Double): this.type = set(regParam, value)
414
415  /** @group setParam */
416  @Since("1.3.0")
417  def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
418
419  /** @group setParam */
420  @Since("1.4.0")
421  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
422
423  /** @group setParam */
424  @Since("1.3.0")
425  def setSeed(value: Long): this.type = set(seed, value)
426
427  /** @group expertSetParam */
428  @Since("2.0.0")
429  def setIntermediateStorageLevel(value: String): this.type = set(intermediateStorageLevel, value)
430
431  /** @group expertSetParam */
432  @Since("2.0.0")
433  def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
434
435  /**
   * Sets both numUserBlocks and numItemBlocks to the specified value.
437   *
438   * @group setParam
439   */
440  @Since("1.3.0")
441  def setNumBlocks(value: Int): this.type = {
442    setNumUserBlocks(value)
443    setNumItemBlocks(value)
444    this
445  }
446
447  @Since("2.0.0")
448  override def fit(dataset: Dataset[_]): ALSModel = {
449    transformSchema(dataset.schema)
450    import dataset.sparkSession.implicits._
451
452    val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
453    val ratings = dataset
454      .select(checkedCast(col($(userCol)).cast(DoubleType)),
455        checkedCast(col($(itemCol)).cast(DoubleType)), r)
456      .rdd
457      .map { row =>
458        Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
459      }
460    val instrLog = Instrumentation.create(this, ratings)
461    instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
462                       userCol, itemCol, ratingCol, predictionCol, maxIter,
463                       regParam, nonnegative, checkpointInterval, seed)
464    val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
465      numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
466      maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
467      alpha = $(alpha), nonnegative = $(nonnegative),
468      intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
469      finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
470      checkpointInterval = $(checkpointInterval), seed = $(seed))
471    val userDF = userFactors.toDF("id", "features")
472    val itemDF = itemFactors.toDF("id", "features")
473    val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
474    instrLog.logSuccess(model)
475    copyValues(model)
476  }
477
478  @Since("1.3.0")
479  override def transformSchema(schema: StructType): StructType = {
480    validateAndTransformSchema(schema)
481  }
482
483  @Since("1.5.0")
484  override def copy(extra: ParamMap): ALS = defaultCopy(extra)
485}
486
487
488/**
489 * :: DeveloperApi ::
490 * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
 * exposed as a developer API for users who need other ID types, but it is not recommended
 * because it increases the shuffle size and memory requirement during training. For simplicity,
493 * users and items must have the same type. The number of distinct users/items should be smaller
494 * than 2 billion.
495 */
496@DeveloperApi
497object ALS extends DefaultParamsReadable[ALS] with Logging {
498
499  /**
500   * :: DeveloperApi ::
501   * Rating class for better code readability.
502   */
503  @DeveloperApi
504  case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
505
506  @Since("1.6.0")
507  override def load(path: String): ALS = super.load(path)
508
509  /** Trait for least squares solvers applied to the normal equation. */
510  private[recommendation] trait LeastSquaresNESolver extends Serializable {
511    /** Solves a least squares problem with regularization (possibly with other constraints). */
512    def solve(ne: NormalEquation, lambda: Double): Array[Float]
513  }
514
515  /** Cholesky solver for least square problems. */
516  private[recommendation] class CholeskySolver extends LeastSquaresNESolver {
517
518    /**
519     * Solves a least squares problem with L2 regularization:
520     *
521     *   min norm(A x - b)^2^ + lambda * norm(x)^2^
522     *
523     * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
524     * @param lambda regularization constant
525     * @return the solution x
526     */
527    override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
528      val k = ne.k
529      // Add scaled lambda to the diagonals of AtA.
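      // In the packed upper-triangular storage used by NormalEquation, the diagonal entries sit
      // at indices 0, 2, 5, 9, ..., i.e. each one is j positions past the previous with j starting
      // at 2 and growing by 1, which is exactly the walk performed by `i += j; j += 1` below.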
530      var i = 0
531      var j = 2
532      while (i < ne.triK) {
533        ne.ata(i) += lambda
534        i += j
535        j += 1
536      }
537      CholeskyDecomposition.solve(ne.ata, ne.atb)
538      val x = new Array[Float](k)
539      i = 0
540      while (i < k) {
541        x(i) = ne.atb(i).toFloat
542        i += 1
543      }
544      ne.reset()
545      x
546    }
547  }
548
549  /** NNLS solver. */
550  private[recommendation] class NNLSSolver extends LeastSquaresNESolver {
551    private var rank: Int = -1
552    private var workspace: NNLS.Workspace = _
553    private var ata: Array[Double] = _
554    private var initialized: Boolean = false
555
556    private def initialize(rank: Int): Unit = {
557      if (!initialized) {
558        this.rank = rank
559        workspace = NNLS.createWorkspace(rank)
560        ata = new Array[Double](rank * rank)
561        initialized = true
562      } else {
563        require(this.rank == rank)
564      }
565    }
566
567    /**
568     * Solves a nonnegative least squares problem with L2 regularization:
569     *
570     *   min_x_  norm(A x - b)^2^ + lambda * n * norm(x)^2^
571     *   subject to x >= 0
572     */
573    override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
574      val rank = ne.k
575      initialize(rank)
576      fillAtA(ne.ata, lambda)
577      val x = NNLS.solve(ata, ne.atb, workspace)
578      ne.reset()
579      x.map(x => x.toFloat)
580    }
581
    /**
     * Given a matrix in the packed upper-triangular order used by [[NormalEquation]], computes the
     * full symmetric square matrix that it represents and stores it into the `ata` workspace,
     * adding `lambda` to the diagonal.
     */
    private def fillAtA(triAtA: Array[Double], lambda: Double): Unit = {
587      var i = 0
588      var pos = 0
589      var a = 0.0
590      while (i < rank) {
591        var j = 0
592        while (j <= i) {
593          a = triAtA(pos)
594          ata(i * rank + j) = a
595          ata(j * rank + i) = a
596          pos += 1
597          j += 1
598        }
599        ata(i * rank + i) += lambda
600        i += 1
601      }
602    }
603  }
604
605  /**
   * Represents the normal equation used to solve the following weighted least squares problem:
607   *
608   * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x.
609   *
610   * Its normal equation is given by
611   *
612   * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0.
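   *
   * A minimal sketch of the intended usage (the values are illustrative only):
   * {{{
   *   val ne = new NormalEquation(2)
   *   ne.add(Array(1.0f, 0.0f), 1.0)           // observation with the default weight c = 1.0
   *   ne.add(Array(0.0f, 1.0f), 2.0, c = 3.0)  // weighted observation
   *   val x = new CholeskySolver().solve(ne, lambda = 0.1)  // solves and resets `ne`
   * }}}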
613   */
614  private[recommendation] class NormalEquation(val k: Int) extends Serializable {
615
616    /** Number of entries in the upper triangular part of a k-by-k matrix. */
617    val triK = k * (k + 1) / 2
618    /** A^T^ * A */
619    val ata = new Array[Double](triK)
620    /** A^T^ * b */
621    val atb = new Array[Double](k)
622
623    private val da = new Array[Double](k)
624    private val upper = "U"
625
626    private def copyToDouble(a: Array[Float]): Unit = {
627      var i = 0
628      while (i < k) {
629        da(i) = a(i)
630        i += 1
631      }
632    }
633
634    /** Adds an observation. */
635    def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
636      require(c >= 0.0)
637      require(a.length == k)
638      copyToDouble(a)
639      blas.dspr(upper, k, c, da, 1, ata)
640      if (b != 0.0) {
641        blas.daxpy(k, c * b, da, 1, atb, 1)
642      }
643      this
644    }
645
646    /** Merges another normal equation object. */
647    def merge(other: NormalEquation): this.type = {
648      require(other.k == k)
649      blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
650      blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
651      this
652    }
653
654    /** Resets everything to zero, which should be called after each solve. */
655    def reset(): Unit = {
656      ju.Arrays.fill(ata, 0.0)
657      ju.Arrays.fill(atb, 0.0)
658    }
659  }
660
661  /**
662   * :: DeveloperApi ::
663   * Implementation of the ALS algorithm.
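   *
   * A minimal sketch of calling this low-level API directly (`sc` and the example ratings are
   * illustrative assumptions):
   * {{{
   *   val ratings: RDD[Rating[Int]] =
   *     sc.parallelize(Seq(Rating(0, 0, 4.0f), Rating(0, 1, 1.0f), Rating(1, 1, 5.0f)))
   *   val (userFactors, itemFactors) = ALS.train(ratings, rank = 10, maxIter = 5)
   * }}}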
664   */
665  @DeveloperApi
666  def train[ID: ClassTag]( // scalastyle:ignore
667      ratings: RDD[Rating[ID]],
668      rank: Int = 10,
669      numUserBlocks: Int = 10,
670      numItemBlocks: Int = 10,
671      maxIter: Int = 10,
672      regParam: Double = 1.0,
673      implicitPrefs: Boolean = false,
674      alpha: Double = 1.0,
675      nonnegative: Boolean = false,
676      intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
677      finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
678      checkpointInterval: Int = 10,
679      seed: Long = 0L)(
680      implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
681    require(intermediateRDDStorageLevel != StorageLevel.NONE,
682      "ALS is not designed to run without persisting intermediate RDDs.")
683    val sc = ratings.sparkContext
684    val userPart = new ALSPartitioner(numUserBlocks)
685    val itemPart = new ALSPartitioner(numItemBlocks)
686    val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
687    val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
688    val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
689    val blockRatings = partitionRatings(ratings, userPart, itemPart)
690      .persist(intermediateRDDStorageLevel)
691    val (userInBlocks, userOutBlocks) =
692      makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
693    // materialize blockRatings and user blocks
694    userOutBlocks.count()
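    // Build the item-side rating blocks by swapping the user/item roles in the already-partitioned
    // rating blocks, instead of partitioning the raw ratings a second time.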
695    val swappedBlockRatings = blockRatings.map {
696      case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
697        ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
698    }
699    val (itemInBlocks, itemOutBlocks) =
700      makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
701    // materialize item blocks
702    itemOutBlocks.count()
703    val seedGen = new XORShiftRandom(seed)
704    var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
705    var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
706    var previousCheckpointFile: Option[String] = None
707    val shouldCheckpoint: Int => Boolean = (iter) =>
708      sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
709    val deletePreviousCheckpointFile: () => Unit = () =>
710      previousCheckpointFile.foreach { file =>
711        try {
712          val checkpointFile = new Path(file)
713          checkpointFile.getFileSystem(sc.hadoopConfiguration).delete(checkpointFile, true)
714        } catch {
715          case e: IOException =>
716            logWarning(s"Cannot delete checkpoint file $file:", e)
717        }
718      }
719    if (implicitPrefs) {
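      // With implicit feedback each factor RDD is used twice per half-iteration (once to compute
      // YtY and once to build the factor messages), so the factors are persisted here; the
      // explicit branch below uses each factor RDD only once and skips the extra persist.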
720      for (iter <- 1 to maxIter) {
721        userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
722        val previousItemFactors = itemFactors
723        itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
724          userLocalIndexEncoder, implicitPrefs, alpha, solver)
725        previousItemFactors.unpersist()
726        itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
727        // TODO: Generalize PeriodicGraphCheckpointer and use it here.
728        val deps = itemFactors.dependencies
729        if (shouldCheckpoint(iter)) {
730          itemFactors.checkpoint() // itemFactors gets materialized in computeFactors
731        }
732        val previousUserFactors = userFactors
733        userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
734          itemLocalIndexEncoder, implicitPrefs, alpha, solver)
735        if (shouldCheckpoint(iter)) {
736          ALS.cleanShuffleDependencies(sc, deps)
737          deletePreviousCheckpointFile()
738          previousCheckpointFile = itemFactors.getCheckpointFile
739        }
740        previousUserFactors.unpersist()
741      }
742    } else {
743      for (iter <- 0 until maxIter) {
744        itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
745          userLocalIndexEncoder, solver = solver)
746        if (shouldCheckpoint(iter)) {
747          val deps = itemFactors.dependencies
748          itemFactors.checkpoint()
749          itemFactors.count() // checkpoint item factors and cut lineage
750          ALS.cleanShuffleDependencies(sc, deps)
751          deletePreviousCheckpointFile()
752          previousCheckpointFile = itemFactors.getCheckpointFile
753        }
754        userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
755          itemLocalIndexEncoder, solver = solver)
756      }
757    }
758    val userIdAndFactors = userInBlocks
759      .mapValues(_.srcIds)
760      .join(userFactors)
761      .mapPartitions({ items =>
762        items.flatMap { case (_, (ids, factors)) =>
763          ids.view.zip(factors)
764        }
765      // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks
766      // and userFactors.
767      }, preservesPartitioning = true)
768      .setName("userFactors")
769      .persist(finalRDDStorageLevel)
770    val itemIdAndFactors = itemInBlocks
771      .mapValues(_.srcIds)
772      .join(itemFactors)
773      .mapPartitions({ items =>
774        items.flatMap { case (_, (ids, factors)) =>
775          ids.view.zip(factors)
776        }
777      }, preservesPartitioning = true)
778      .setName("itemFactors")
779      .persist(finalRDDStorageLevel)
780    if (finalRDDStorageLevel != StorageLevel.NONE) {
781      userIdAndFactors.count()
782      itemFactors.unpersist()
783      itemIdAndFactors.count()
784      userInBlocks.unpersist()
785      userOutBlocks.unpersist()
786      itemInBlocks.unpersist()
787      itemOutBlocks.unpersist()
788      blockRatings.unpersist()
789    }
790    (userIdAndFactors, itemIdAndFactors)
791  }
792
793  /**
794   * Factor block that stores factors (Array[Float]) in an Array.
795   */
796  private type FactorBlock = Array[Array[Float]]
797
798  /**
799   * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to
800   * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the
801   * src factors in this block to send to dst block 0.
802   */
803  private type OutBlock = Array[Array[Int]]
804
805  /**
806   * In-link block for computing src (user/item) factors. This includes the original src IDs
807   * of the elements within this block as well as encoded dst (item/user) indices and corresponding
808   * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original
809   * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices.
810   * For example, if we have an in-link record
811   *
812   * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0},
813   *
814   * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which
815   * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3).
816   *
817   * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can
818   * compute src factors one after another using only one normal equation instance.
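   *
   * For example (illustrative values), srcIds = [7, 9] with dstPtrs = [0, 2, 3] means that
   * dstEncodedIndices(0..1) and ratings(0..1) belong to src id 7, while dstEncodedIndices(2) and
   * ratings(2) belong to src id 9.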
819   *
820   * @param srcIds src ids (ordered)
821   * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and
822   *                ratings are associated with srcIds(i).
823   * @param dstEncodedIndices encoded dst indices
824   * @param ratings ratings
825   * @see [[LocalIndexEncoder]]
826   */
827  private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
828      srcIds: Array[ID],
829      dstPtrs: Array[Int],
830      dstEncodedIndices: Array[Int],
831      ratings: Array[Float]) {
832    /** Size of the block. */
833    def size: Int = ratings.length
834    require(dstEncodedIndices.length == size)
835    require(dstPtrs.length == srcIds.length + 1)
836  }
837
838  /**
839   * Initializes factors randomly given the in-link blocks.
840   *
841   * @param inBlocks in-link blocks
842   * @param rank rank
843   * @return initialized factor blocks
844   */
845  private def initialize[ID](
846      inBlocks: RDD[(Int, InBlock[ID])],
847      rank: Int,
848      seed: Long): RDD[(Int, FactorBlock)] = {
    // Choose a unit vector uniformly at random from the unit sphere, but from the
    // "first quadrant" where all elements are nonnegative. This can be done by choosing
    // elements distributed as Normal(0,1), taking the absolute value, and then normalizing.
    // This appears to create factorizations that have a slightly better reconstruction
    // (<1%) compared with picking elements uniformly at random in [0,1].
854    inBlocks.map { case (srcBlockId, inBlock) =>
855      val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId))
856      val factors = Array.fill(inBlock.srcIds.length) {
857        val factor = Array.fill(rank)(random.nextGaussian().toFloat)
858        val nrm = blas.snrm2(rank, factor, 1)
859        blas.sscal(rank, 1.0f / nrm, factor, 1)
860        factor
861      }
862      (srcBlockId, factors)
863    }
864  }
865
866  /**
867   * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
868   */
869  private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
870      srcIds: Array[ID],
871      dstIds: Array[ID],
872      ratings: Array[Float]) {
873    /** Size of the block. */
874    def size: Int = srcIds.length
875    require(dstIds.length == srcIds.length)
876    require(ratings.length == srcIds.length)
877  }
878
879  /**
880   * Builder for [[RatingBlock]]. `mutable.ArrayBuilder` is used to avoid boxing/unboxing.
881   */
882  private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag]
883    extends Serializable {
884
885    private val srcIds = mutable.ArrayBuilder.make[ID]
886    private val dstIds = mutable.ArrayBuilder.make[ID]
887    private val ratings = mutable.ArrayBuilder.make[Float]
888    var size = 0
889
890    /** Adds a rating. */
891    def add(r: Rating[ID]): this.type = {
892      size += 1
893      srcIds += r.user
894      dstIds += r.item
895      ratings += r.rating
896      this
897    }
898
899    /** Merges another [[RatingBlockBuilder]]. */
900    def merge(other: RatingBlock[ID]): this.type = {
901      size += other.srcIds.length
902      srcIds ++= other.srcIds
903      dstIds ++= other.dstIds
904      ratings ++= other.ratings
905      this
906    }
907
908    /** Builds a [[RatingBlock]]. */
909    def build(): RatingBlock[ID] = {
910      RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
911    }
912  }
913
914  /**
915   * Partitions raw ratings into blocks.
916   *
917   * @param ratings raw ratings
918   * @param srcPart partitioner for src IDs
919   * @param dstPart partitioner for dst IDs
920   * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
921   */
922  private def partitionRatings[ID: ClassTag](
923      ratings: RDD[Rating[ID]],
924      srcPart: Partitioner,
925      dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {
926
     /* The implementation produces the same result as the following but generates fewer objects.
928
929     ratings.map { r =>
930       ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r)
931     }.aggregateByKey(new RatingBlockBuilder)(
932         seqOp = (b, r) => b.add(r),
933         combOp = (b0, b1) => b0.merge(b1.build()))
934       .mapValues(_.build())
935     */
936
937    val numPartitions = srcPart.numPartitions * dstPart.numPartitions
938    ratings.mapPartitions { iter =>
939      val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID])
940      iter.flatMap { r =>
941        val srcBlockId = srcPart.getPartition(r.user)
942        val dstBlockId = dstPart.getPartition(r.item)
943        val idx = srcBlockId + srcPart.numPartitions * dstBlockId
944        val builder = builders(idx)
945        builder.add(r)
        if (builder.size >= 2048) { // flush this builder: 2048 ratings * (3 * 4 bytes) = 24 KB
947          builders(idx) = new RatingBlockBuilder
948          Iterator.single(((srcBlockId, dstBlockId), builder.build()))
949        } else {
950          Iterator.empty
951        }
952      } ++ {
953        builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
954          val srcBlockId = idx % srcPart.numPartitions
955          val dstBlockId = idx / srcPart.numPartitions
956          ((srcBlockId, dstBlockId), block.build())
957        }
958      }
959    }.groupByKey().mapValues { blocks =>
960      val builder = new RatingBlockBuilder[ID]
961      blocks.foreach(builder.merge)
962      builder.build()
963    }.setName("ratingBlocks")
964  }
965
966  /**
967   * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
968   *
969   * @param encoder encoder for dst indices
970   */
971  private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
972      encoder: LocalIndexEncoder)(
973      implicit ord: Ordering[ID]) {
974
975    private val srcIds = mutable.ArrayBuilder.make[ID]
976    private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
977    private val ratings = mutable.ArrayBuilder.make[Float]
978
979    /**
980     * Adds a dst block of (srcId, dstLocalIndex, rating) tuples.
981     *
982     * @param dstBlockId dst block ID
983     * @param srcIds original src IDs
984     * @param dstLocalIndices dst local indices
985     * @param ratings ratings
986     */
987    def add(
988        dstBlockId: Int,
989        srcIds: Array[ID],
990        dstLocalIndices: Array[Int],
991        ratings: Array[Float]): this.type = {
992      val sz = srcIds.length
993      require(dstLocalIndices.length == sz)
994      require(ratings.length == sz)
995      this.srcIds ++= srcIds
996      this.ratings ++= ratings
997      var j = 0
998      while (j < sz) {
999        this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
1000        j += 1
1001      }
1002      this
1003    }
1004
1005    /** Builds a [[UncompressedInBlock]]. */
1006    def build(): UncompressedInBlock[ID] = {
1007      new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
1008    }
1009  }
1010
1011  /**
1012   * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
1013   */
1014  private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag](
1015      val srcIds: Array[ID],
1016      val dstEncodedIndices: Array[Int],
1017      val ratings: Array[Float])(
1018      implicit ord: Ordering[ID]) {
1019
    /** Size of the block. */
1021    def length: Int = srcIds.length
1022
1023    /**
1024     * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
1025     * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
1026     * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
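     *
     * For example (illustrative values), sorted srcIds = [7, 7, 9] compress to
     * uniqueSrcIds = [7, 9] with dstPtrs = [0, 2, 3].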
1027     */
1028    def compress(): InBlock[ID] = {
1029      val sz = length
1030      assert(sz > 0, "Empty in-link block should not exist.")
1031      sort()
1032      val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID]
1033      val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
1034      var preSrcId = srcIds(0)
1035      uniqueSrcIdsBuilder += preSrcId
1036      var curCount = 1
1037      var i = 1
1038      var j = 0
1039      while (i < sz) {
1040        val srcId = srcIds(i)
1041        if (srcId != preSrcId) {
1042          uniqueSrcIdsBuilder += srcId
1043          dstCountsBuilder += curCount
1044          preSrcId = srcId
1045          j += 1
1046          curCount = 0
1047        }
1048        curCount += 1
1049        i += 1
1050      }
1051      dstCountsBuilder += curCount
1052      val uniqueSrcIds = uniqueSrcIdsBuilder.result()
      val numUniqueSrcIds = uniqueSrcIds.length
      val dstCounts = dstCountsBuilder.result()
      val dstPtrs = new Array[Int](numUniqueSrcIds + 1)
      var sum = 0
      i = 0
      while (i < numUniqueSrcIds) {
1059        sum += dstCounts(i)
1060        i += 1
1061        dstPtrs(i) = sum
1062      }
1063      InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
1064    }
1065
1066    private def sort(): Unit = {
1067      val sz = length
1068      // Since there might be interleaved log messages, we insert a unique id for easy pairing.
1069      val sortId = Utils.random.nextInt()
1070      logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
1071      val start = System.nanoTime()
1072      val sorter = new Sorter(new UncompressedInBlockSort[ID])
1073      sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]])
1074      val duration = (System.nanoTime() - start) / 1e9
1075      logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
1076    }
1077  }
1078
1079  /**
1080   * A wrapper that holds a primitive key.
1081   *
1082   * @see [[UncompressedInBlockSort]]
1083   */
1084  private class KeyWrapper[@specialized(Int, Long) ID: ClassTag](
1085      implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {
1086
1087    var key: ID = _
1088
1089    override def compare(that: KeyWrapper[ID]): Int = {
1090      ord.compare(key, that.key)
1091    }
1092
1093    def setKey(key: ID): this.type = {
1094      this.key = key
1095      this
1096    }
1097  }
1098
1099  /**
1100   * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
1101   */
1102  private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag](
1103      implicit ord: Ordering[ID])
1104    extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] {
1105
1106    override def newKey(): KeyWrapper[ID] = new KeyWrapper()
1107
1108    override def getKey(
1109        data: UncompressedInBlock[ID],
1110        pos: Int,
1111        reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
1112      if (reuse == null) {
1113        new KeyWrapper().setKey(data.srcIds(pos))
1114      } else {
1115        reuse.setKey(data.srcIds(pos))
1116      }
1117    }
1118
1119    override def getKey(
1120        data: UncompressedInBlock[ID],
1121        pos: Int): KeyWrapper[ID] = {
1122      getKey(data, pos, null)
1123    }
1124
1125    private def swapElements[@specialized(Int, Float) T](
1126        data: Array[T],
1127        pos0: Int,
1128        pos1: Int): Unit = {
1129      val tmp = data(pos0)
1130      data(pos0) = data(pos1)
1131      data(pos1) = tmp
1132    }
1133
1134    override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = {
1135      swapElements(data.srcIds, pos0, pos1)
1136      swapElements(data.dstEncodedIndices, pos0, pos1)
1137      swapElements(data.ratings, pos0, pos1)
1138    }
1139
1140    override def copyRange(
1141        src: UncompressedInBlock[ID],
1142        srcPos: Int,
1143        dst: UncompressedInBlock[ID],
1144        dstPos: Int,
1145        length: Int): Unit = {
1146      System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
1147      System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
1148      System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
1149    }
1150
1151    override def allocate(length: Int): UncompressedInBlock[ID] = {
1152      new UncompressedInBlock(
1153        new Array[ID](length), new Array[Int](length), new Array[Float](length))
1154    }
1155
1156    override def copyElement(
1157        src: UncompressedInBlock[ID],
1158        srcPos: Int,
1159        dst: UncompressedInBlock[ID],
1160        dstPos: Int): Unit = {
1161      dst.srcIds(dstPos) = src.srcIds(srcPos)
1162      dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
1163      dst.ratings(dstPos) = src.ratings(srcPos)
1164    }
1165  }
1166
1167  /**
1168   * Creates in-blocks and out-blocks from rating blocks.
1169   *
1170   * @param prefix prefix for in/out-block names
1171   * @param ratingBlocks rating blocks
1172   * @param srcPart partitioner for src IDs
1173   * @param dstPart partitioner for dst IDs
1174   * @return (in-blocks, out-blocks)
1175   */
1176  private def makeBlocks[ID: ClassTag](
1177      prefix: String,
1178      ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
1179      srcPart: Partitioner,
1180      dstPart: Partitioner,
1181      storageLevel: StorageLevel)(
1182      implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
1183    val inBlocks = ratingBlocks.map {
1184      case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
1185        // The implementation is a faster version of
1186        // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
1187        val start = System.nanoTime()
1188        val dstIdSet = new OpenHashSet[ID](1 << 20)
1189        dstIds.foreach(dstIdSet.add)
1190        val sortedDstIds = new Array[ID](dstIdSet.size)
1191        var i = 0
1192        var pos = dstIdSet.nextPos(0)
1193        while (pos != -1) {
1194          sortedDstIds(i) = dstIdSet.getValue(pos)
1195          pos = dstIdSet.nextPos(pos + 1)
1196          i += 1
1197        }
1198        assert(i == dstIdSet.size)
1199        Sorting.quickSort(sortedDstIds)
1200        val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length)
1201        i = 0
1202        while (i < sortedDstIds.length) {
1203          dstIdToLocalIndex.update(sortedDstIds(i), i)
1204          i += 1
1205        }
1206        logDebug(
1207          "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
1208        val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
1209        (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
1210    }.groupByKey(new ALSPartitioner(srcPart.numPartitions))
1211      .mapValues { iter =>
1212        val builder =
1213          new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
1214        iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
1215          builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
1216        }
1217        builder.build().compress()
1218      }.setName(prefix + "InBlocks")
1219      .persist(storageLevel)
1220    val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
1221      val encoder = new LocalIndexEncoder(dstPart.numPartitions)
1222      val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
1223      var i = 0
1224      val seen = new Array[Boolean](dstPart.numPartitions)
1225      while (i < srcIds.length) {
1226        var j = dstPtrs(i)
1227        ju.Arrays.fill(seen, false)
1228        while (j < dstPtrs(i + 1)) {
1229          val dstBlockId = encoder.blockId(dstEncodedIndices(j))
1230          if (!seen(dstBlockId)) {
1231            activeIds(dstBlockId) += i // add the local index in this out-block
1232            seen(dstBlockId) = true
1233          }
1234          j += 1
1235        }
1236        i += 1
1237      }
1238      activeIds.map { x =>
1239        x.result()
1240      }
1241    }.setName(prefix + "OutBlocks")
1242      .persist(storageLevel)
1243    (inBlocks, outBlocks)
1244  }
1245
1246  /**
1247   * Compute dst factors by constructing and solving least square problems.
1248   *
1249   * @param srcFactorBlocks src factors
1250   * @param srcOutBlocks src out-blocks
1251   * @param dstInBlocks dst in-blocks
1252   * @param rank rank
1253   * @param regParam regularization constant
1254   * @param srcEncoder encoder for src local indices
1255   * @param implicitPrefs whether to use implicit preference
1256   * @param alpha the alpha constant in the implicit preference formulation
1257   * @param solver solver for least squares problems
1258   * @return dst factors
1259   */
1260  private def computeFactors[ID](
1261      srcFactorBlocks: RDD[(Int, FactorBlock)],
1262      srcOutBlocks: RDD[(Int, OutBlock)],
1263      dstInBlocks: RDD[(Int, InBlock[ID])],
1264      rank: Int,
1265      regParam: Double,
1266      srcEncoder: LocalIndexEncoder,
1267      implicitPrefs: Boolean = false,
1268      alpha: Double = 1.0,
1269      solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
1270    val numSrcBlocks = srcFactorBlocks.partitions.length
1271    val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
1272    val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
1273      case (srcBlockId, (srcOutBlock, srcFactors)) =>
1274        srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
1275          (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
1276        }
1277    }
1278    val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length))
1279    dstInBlocks.join(merged).mapValues {
1280      case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
1281        val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
1282        srcFactors.foreach { case (srcBlockId, factors) =>
1283          sortedSrcFactors(srcBlockId) = factors
1284        }
1285        val dstFactors = new Array[Array[Float]](dstIds.length)
1286        var j = 0
1287        val ls = new NormalEquation(rank)
1288        while (j < dstIds.length) {
1289          ls.reset()
1290          if (implicitPrefs) {
1291            ls.merge(YtY.get)
1292          }
1293          var i = srcPtrs(j)
1294          var numExplicits = 0
1295          while (i < srcPtrs(j + 1)) {
1296            val encoded = srcEncodedIndices(i)
1297            val blockId = srcEncoder.blockId(encoded)
1298            val localIndex = srcEncoder.localIndex(encoded)
1299            val srcFactor = sortedSrcFactors(blockId)(localIndex)
1300            val rating = ratings(i)
1301            if (implicitPrefs) {
              // Extension to the original paper to handle rating < 0. Confidence is a function of
              // |rating| instead so that it is never negative. c1 is confidence - 1.0.
1304              val c1 = alpha * math.abs(rating)
1305              // For rating <= 0, the corresponding preference is 0. So the term below is only added
1306              // for rating > 0. Because YtY is already added, we need to adjust the scaling here.
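              // With add(a, b, c), AtA accumulates c * a * a^T and Atb accumulates c * b * a.
              // Using b = (c1 + 1) / c1 and c = c1 on top of the YtY term (weight 1) yields a
              // total AtA weight of 1 + c1 = confidence and an Atb contribution of (c1 + 1) * a,
              // i.e. confidence * preference * a, matching the implicit-feedback normal equations.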
1307              if (rating > 0) {
1308                numExplicits += 1
1309                ls.add(srcFactor, (c1 + 1.0) / c1, c1)
1310              }
1311            } else {
1312              ls.add(srcFactor, rating)
1313              numExplicits += 1
1314            }
1315            i += 1
1316          }
1317          // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
1318          dstFactors(j) = solver.solve(ls, numExplicits * regParam)
1319          j += 1
1320        }
1321        dstFactors
1322    }
1323  }
1324
1325  /**
1326   * Computes the Gramian matrix of user or item factors, which is only used in implicit preference.
1327   * Caching of the input factors is handled in [[ALS#train]].
1328   */
1329  private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
1330    factorBlocks.values.aggregate(new NormalEquation(rank))(
1331      seqOp = (ne, factors) => {
1332        factors.foreach(ne.add(_, 0.0))
1333        ne
1334      },
1335      combOp = (ne1, ne2) => ne1.merge(ne2))
1336  }
1337
1338  /**
1339   * Encoder for storing (blockId, localIndex) into a single integer.
1340   *
1341   * We use the leading bits (including the sign bit) to store the block id and the rest to store
1342   * the local index. This is based on the assumption that users/items are approximately evenly
1343   * partitioned. With this assumption, we should be able to encode two billion distinct values.
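   *
   * For example, with numBlocks = 10 the local index occupies the lower 28 bits, so
   * encode(3, 5) = (3 << 28) | 5, and blockId/localIndex recover 3 and 5 from that value.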
1344   *
1345   * @param numBlocks number of blocks
1346   */
1347  private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable {
1348
1349    require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.")
1350
1351    private[this] final val numLocalIndexBits =
1352      math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)
1353    private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1
1354
1355    /** Encodes a (blockId, localIndex) into a single integer. */
1356    def encode(blockId: Int, localIndex: Int): Int = {
1357      require(blockId < numBlocks)
1358      require((localIndex & ~localIndexMask) == 0)
1359      (blockId << numLocalIndexBits) | localIndex
1360    }
1361
1362    /** Gets the block id from an encoded index. */
1363    @inline
1364    def blockId(encoded: Int): Int = {
1365      encoded >>> numLocalIndexBits
1366    }
1367
1368    /** Gets the local index from an encoded index. */
1369    @inline
1370    def localIndex(encoded: Int): Int = {
1371      encoded & localIndexMask
1372    }
1373  }
1374
1375  /**
1376   * Partitioner used by ALS. We require that getPartition is a projection. That is, for any key k,
1377   * we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner
1378   * satisfies this requirement, we simply use a type alias here.
1379   */
1380  private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner
1381
1382  /**
   * Private function to clean up all of the shuffle files from the dependencies and their parents.
1384   */
1385  private[spark] def cleanShuffleDependencies[T](
1386      sc: SparkContext,
1387      deps: Seq[Dependency[_]],
1388      blocking: Boolean = false): Unit = {
1389    // If there is no reference tracking we skip clean up.
1390    sc.cleaner.foreach { cleaner =>
1391      /**
       * Cleans the shuffle of the given dependency and, recursively, those of its unpersisted
       * parents.
1393       */
1394      def cleanEagerly(dep: Dependency[_]): Unit = {
1395        if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) {
1396          val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
1397          cleaner.doCleanupShuffle(shuffleId, blocking)
1398        }
1399        val rdd = dep.rdd
1400        val rddDeps = rdd.dependencies
1401        if (rdd.getStorageLevel == StorageLevel.NONE && rddDeps != null) {
1402          rddDeps.foreach(cleanEagerly)
1403        }
1404      }
1405      deps.foreach(cleanEagerly)
1406    }
1407  }
1408}
1409