/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.recommendation

import java.{util => ju}
import java.io.IOException

import scala.collection.mutable
import scala.reflect.ClassTag
import scala.util.{Sorting, Try}
import scala.util.hashing.byteswap64

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._

import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom

/**
 * Common params for ALS and ALSModel.
 */
private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
  /**
   * Param for the column name for user ids. Ids must be integers. Other
   * numeric types are supported for this column, but will be cast to integers as long as they
   * fall within the integer value range.
   * Default: "user"
   * @group param
   */
  val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " +
    "the integer value range.")

  /** @group getParam */
  def getUserCol: String = $(userCol)

  /**
   * Param for the column name for item ids. Ids must be integers. Other
   * numeric types are supported for this column, but will be cast to integers as long as they
   * fall within the integer value range.
   * Default: "item"
   * @group param
   */
  val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " +
    "the integer value range.")

  /** @group getParam */
  def getItemCol: String = $(itemCol)

  /**
   * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
   * out of integer range.
   */
  protected val checkedCast = udf { (n: Double) =>
    if (n > Int.MaxValue || n < Int.MinValue) {
      throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " +
        s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.")
    } else {
      n.toInt
    }
  }
}
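
// A minimal sketch of the range check performed by `checkedCast` above (illustrative
// values only, not part of the ALS API):
//
//   val ok: Double = 42.0        // 42.0.toInt == 42, within Int range, so the UDF casts it
//   val tooBig: Double = 3.0e10  // > Int.MaxValue (2147483647), so the UDF throws
//   assert(ok <= Int.MaxValue && ok >= Int.MinValue)
//   assert(tooBig > Int.MaxValue)
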
/**
 * Common params for ALS.
 */
private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
  with HasPredictionCol with HasCheckpointInterval with HasSeed {

  /**
   * Param for rank of the matrix factorization (positive).
   * Default: 10
   * @group param
   */
  val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))

  /** @group getParam */
  def getRank: Int = $(rank)

  /**
   * Param for number of user blocks (positive).
   * Default: 10
   * @group param
   */
  val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
    ParamValidators.gtEq(1))

  /** @group getParam */
  def getNumUserBlocks: Int = $(numUserBlocks)

  /**
   * Param for number of item blocks (positive).
   * Default: 10
   * @group param
   */
  val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
    ParamValidators.gtEq(1))

  /** @group getParam */
  def getNumItemBlocks: Int = $(numItemBlocks)

  /**
   * Param to decide whether to use implicit preference.
   * Default: false
   * @group param
   */
  val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference")

  /** @group getParam */
  def getImplicitPrefs: Boolean = $(implicitPrefs)

  /**
   * Param for the alpha parameter in the implicit preference formulation (nonnegative).
   * Default: 1.0
   * @group param
   */
  val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference",
    ParamValidators.gtEq(0))

  /** @group getParam */
  def getAlpha: Double = $(alpha)

  /**
   * Param for the column name for ratings.
   * Default: "rating"
   * @group param
   */
  val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")

  /** @group getParam */
  def getRatingCol: String = $(ratingCol)

  /**
   * Param for whether to apply nonnegativity constraints.
   * Default: false
   * @group param
   */
  val nonnegative = new BooleanParam(
    this, "nonnegative", "whether to use nonnegative constraint for least squares")

  /** @group getParam */
  def getNonnegative: Boolean = $(nonnegative)

  /**
   * Param for StorageLevel for intermediate datasets. Pass in a string representation of
   * `StorageLevel`. Cannot be "NONE".
   * Default: "MEMORY_AND_DISK".
   *
   * @group expertParam
   */
  val intermediateStorageLevel = new Param[String](this, "intermediateStorageLevel",
    "StorageLevel for intermediate datasets. Cannot be 'NONE'.",
    (s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE")

  /** @group expertGetParam */
  def getIntermediateStorageLevel: String = $(intermediateStorageLevel)

  /**
   * Param for StorageLevel for ALS model factors. Pass in a string representation of
   * `StorageLevel`.
   * Default: "MEMORY_AND_DISK".
   *
   * @group expertParam
   */
  val finalStorageLevel = new Param[String](this, "finalStorageLevel",
    "StorageLevel for ALS model factors.",
    (s: String) => Try(StorageLevel.fromString(s)).isSuccess)

  /** @group expertGetParam */
  def getFinalStorageLevel: String = $(finalStorageLevel)

  setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
    implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")

  /**
   * Validates and transforms the input schema.
   *
   * @param schema input schema
   * @return output schema
   */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    // user and item will be cast to Int
    SchemaUtils.checkNumericType(schema, $(userCol))
    SchemaUtils.checkNumericType(schema, $(itemCol))
    // rating will be cast to Float
    SchemaUtils.checkNumericType(schema, $(ratingCol))
    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
  }
}

/**
 * Model fitted by ALS.
 *
 * @param rank rank of the matrix factorization model
 * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
 * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
 */
@Since("1.3.0")
class ALSModel private[ml] (
    @Since("1.4.0") override val uid: String,
    @Since("1.4.0") val rank: Int,
    @transient val userFactors: DataFrame,
    @transient val itemFactors: DataFrame)
  extends Model[ALSModel] with ALSModelParams with MLWritable {

  /** @group setParam */
  @Since("1.4.0")
  def setUserCol(value: String): this.type = set(userCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setItemCol(value: String): this.type = set(itemCol, value)

  /** @group setParam */
  @Since("1.3.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)
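
  // What `transform` below computes per row, in essence (a hedged sketch with made-up
  // factor values; the real code uses BLAS via `blas.sdot` on the joined factor columns):
  //
  //   val userFeatures = Array(0.1f, 0.2f)  // hypothetical user factor, rank = 2
  //   val itemFeatures = Array(0.3f, 0.4f)  // hypothetical item factor
  //   val prediction = userFeatures.zip(itemFeatures).map { case (u, i) => u * i }.sum
  //   // == 0.1f * 0.3f + 0.2f * 0.4f; rows with no matching factors yield Float.NaN
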
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema)
    // Register a UDF for the prediction and apply it to the joined user and item factors
    // to create the prediction column.
    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
      if (userFeatures != null && itemFeatures != null) {
        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
      } else {
        Float.NaN
      }
    }
    dataset
      .join(userFactors,
        checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
      .join(itemFactors,
        checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
      .select(dataset("*"),
        predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
  }

  @Since("1.3.0")
  override def transformSchema(schema: StructType): StructType = {
    // user and item will be cast to Int
    SchemaUtils.checkNumericType(schema, $(userCol))
    SchemaUtils.checkNumericType(schema, $(itemCol))
    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): ALSModel = {
    val copied = new ALSModel(uid, rank, userFactors, itemFactors)
    copyValues(copied, extra).setParent(parent)
  }

  @Since("1.6.0")
  override def write: MLWriter = new ALSModel.ALSModelWriter(this)
}

@Since("1.6.0")
object ALSModel extends MLReadable[ALSModel] {

  @Since("1.6.0")
  override def read: MLReader[ALSModel] = new ALSModelReader

  @Since("1.6.0")
  override def load(path: String): ALSModel = super.load(path)

  private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {

    override protected def saveImpl(path: String): Unit = {
      val extraMetadata = "rank" -> instance.rank
      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
      val userPath = new Path(path, "userFactors").toString
      instance.userFactors.write.format("parquet").save(userPath)
      val itemPath = new Path(path, "itemFactors").toString
      instance.itemFactors.write.format("parquet").save(itemPath)
    }
  }

  private class ALSModelReader extends MLReader[ALSModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[ALSModel].getName

    override def load(path: String): ALSModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      implicit val format = DefaultFormats
      val rank = (metadata.metadata \ "rank").extract[Int]
      val userPath = new Path(path, "userFactors").toString
      val userFactors = sparkSession.read.format("parquet").load(userPath)
      val itemPath = new Path(path, "itemFactors").toString
      val itemFactors = sparkSession.read.format("parquet").load(itemPath)

      val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)

      DefaultParamsReader.getAndSetParams(model, metadata)
      model
    }
  }
}
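
// A hedged usage sketch of the persistence API above (the path and the `model` value
// are hypothetical):
//
//   // model: ALSModel obtained from ALS.fit(...)
//   model.write.overwrite().save("/tmp/als-model")
//   val restored = ALSModel.load("/tmp/als-model")
//   // `restored` carries the same rank, params, and factor DataFrames as `model`.
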
/**
 * Alternating Least Squares (ALS) matrix factorization.
 *
 * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
 * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
 * The general approach is iterative. During each iteration, one of the factor matrices is held
 * constant, while the other is solved for using least squares. The newly-solved factor matrix is
 * then held constant while solving for the other factor matrix.
 *
 * This is a blocked implementation of the ALS factorization algorithm that groups the two sets
 * of factors (referred to as "users" and "products") into blocks and reduces communication by only
 * sending one copy of each user vector to each product block on each iteration, and only for the
 * product blocks that need that user's feature vector. This is achieved by pre-computing some
 * information about the ratings matrix to determine the "out-links" of each user (which blocks of
 * products it will contribute to) and "in-link" information for each product (which of the feature
 * vectors it receives from each user block it will depend on). This allows us to send only an
 * array of feature vectors between each user block and product block, and have the product block
 * find the users' ratings and update the products based on these messages.
 *
 * For implicit preference data, the algorithm used is based on
 * "Collaborative Filtering for Implicit Feedback Datasets", available at
 * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here.
 *
 * Essentially, instead of finding the low-rank approximations to the rating matrix `R`,
 * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
 * r is greater than 0 and 0 if r is less than or equal to 0. The ratings then act as 'confidence'
 * values related to strength of indicated user preferences rather than explicit ratings given to
 * items.
 */
@Since("1.3.0")
class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams
  with DefaultParamsWritable {

  import org.apache.spark.ml.recommendation.ALS.Rating

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("als"))

  /** @group setParam */
  @Since("1.3.0")
  def setRank(value: Int): this.type = set(rank, value)

  /** @group setParam */
  @Since("1.3.0")
  def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)

  /** @group setParam */
  @Since("1.3.0")
  def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)

  /** @group setParam */
  @Since("1.3.0")
  def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)

  /** @group setParam */
  @Since("1.3.0")
  def setAlpha(value: Double): this.type = set(alpha, value)

  /** @group setParam */
  @Since("1.3.0")
  def setUserCol(value: String): this.type = set(userCol, value)

  /** @group setParam */
  @Since("1.3.0")
  def setItemCol(value: String): this.type = set(itemCol, value)

  /** @group setParam */
  @Since("1.3.0")
  def setRatingCol(value: String): this.type = set(ratingCol, value)

  /** @group setParam */
  @Since("1.3.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("1.3.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)

  /** @group setParam */
  @Since("1.3.0")
  def setRegParam(value: Double): this.type = set(regParam, value)

  /** @group setParam */
  @Since("1.3.0")
  def setNonnegative(value: Boolean): this.type = set(nonnegative, value)

  /** @group setParam */
  @Since("1.4.0")
  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)

  /** @group setParam */
  @Since("1.3.0")
  def setSeed(value: Long): this.type = set(seed, value)
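
  // A hedged end-to-end sketch of fitting this estimator (the `ratings` DataFrame with
  // columns "user", "item", "rating" is hypothetical):
  //
  //   val als = new ALS()
  //     .setMaxIter(10)
  //     .setRegParam(0.1)
  //     .setRank(10)
  //   val model = als.fit(ratings)                 // returns an ALSModel
  //   val predictions = model.transform(ratings)   // adds a "prediction" column
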
  /** @group expertSetParam */
  @Since("2.0.0")
  def setIntermediateStorageLevel(value: String): this.type = set(intermediateStorageLevel, value)

  /** @group expertSetParam */
  @Since("2.0.0")
  def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)

  /**
   * Sets both numUserBlocks and numItemBlocks to the specific value.
   *
   * @group setParam
   */
  @Since("1.3.0")
  def setNumBlocks(value: Int): this.type = {
    setNumUserBlocks(value)
    setNumItemBlocks(value)
    this
  }

  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): ALSModel = {
    transformSchema(dataset.schema)
    import dataset.sparkSession.implicits._

    val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
    val ratings = dataset
      .select(checkedCast(col($(userCol)).cast(DoubleType)),
        checkedCast(col($(itemCol)).cast(DoubleType)), r)
      .rdd
      .map { row =>
        Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
      }
    val instrLog = Instrumentation.create(this, ratings)
    instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
      userCol, itemCol, ratingCol, predictionCol, maxIter,
      regParam, nonnegative, checkpointInterval, seed)
    val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
      numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
      maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
      alpha = $(alpha), nonnegative = $(nonnegative),
      intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
      finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
      checkpointInterval = $(checkpointInterval), seed = $(seed))
    val userDF = userFactors.toDF("id", "features")
    val itemDF = itemFactors.toDF("id", "features")
    val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
    instrLog.logSuccess(model)
    copyValues(model)
  }

  @Since("1.3.0")
  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): ALS = defaultCopy(extra)
}


/**
 * :: DeveloperApi ::
 * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
 * exposed as a developer API for users who do need other ID types. But it is not recommended
 * because it increases the shuffle size and memory requirement during training. For simplicity,
 * users and items must have the same type. The number of distinct users/items should be smaller
 * than 2 billion.
 */
@DeveloperApi
object ALS extends DefaultParamsReadable[ALS] with Logging {

  /**
   * :: DeveloperApi ::
   * Rating class for better code readability.
   */
  @DeveloperApi
  case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)

  @Since("1.6.0")
  override def load(path: String): ALS = super.load(path)

  /** Trait for least squares solvers applied to the normal equation. */
  private[recommendation] trait LeastSquaresNESolver extends Serializable {
    /** Solves a least squares problem with regularization (possibly with other constraints). */
    def solve(ne: NormalEquation, lambda: Double): Array[Float]
  }
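
  // An illustrative aside on the packed storage used below (not part of the API):
  // NormalEquation keeps A^T A in LAPACK's packed upper-triangular layout ("U"), where
  // entry (i, j) with i <= j lives at index i + j * (j + 1) / 2. The diagonal entry (j, j)
  // therefore sits at index j * (j + 3) / 2, i.e. 0, 2, 5, 9, 14, ... -- exactly the
  // indices the CholeskySolver loop below visits when adding lambda to the diagonal:
  //
  //   val k = 4
  //   val diagIndices = (0 until k).map(j => j * (j + 3) / 2)  // Vector(0, 2, 5, 9)
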
  /** Cholesky solver for least squares problems. */
  private[recommendation] class CholeskySolver extends LeastSquaresNESolver {

    /**
     * Solves a least squares problem with L2 regularization:
     *
     *   min norm(A x - b)^2^ + lambda * norm(x)^2^
     *
     * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
     * @param lambda regularization constant
     * @return the solution x
     */
    override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
      val k = ne.k
      // Add scaled lambda to the diagonals of AtA.
      var i = 0
      var j = 2
      while (i < ne.triK) {
        ne.ata(i) += lambda
        i += j
        j += 1
      }
      CholeskyDecomposition.solve(ne.ata, ne.atb)
      val x = new Array[Float](k)
      i = 0
      while (i < k) {
        x(i) = ne.atb(i).toFloat
        i += 1
      }
      ne.reset()
      x
    }
  }

  /** NNLS solver. */
  private[recommendation] class NNLSSolver extends LeastSquaresNESolver {
    private var rank: Int = -1
    private var workspace: NNLS.Workspace = _
    private var ata: Array[Double] = _
    private var initialized: Boolean = false

    private def initialize(rank: Int): Unit = {
      if (!initialized) {
        this.rank = rank
        workspace = NNLS.createWorkspace(rank)
        ata = new Array[Double](rank * rank)
        initialized = true
      } else {
        require(this.rank == rank)
      }
    }

    /**
     * Solves a nonnegative least squares problem with L2 regularization:
     *
     *   min_x_  norm(A x - b)^2^ + lambda * n * norm(x)^2^
     *   subject to x >= 0
     */
    override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
      val rank = ne.k
      initialize(rank)
      fillAtA(ne.ata, lambda)
      val x = NNLS.solve(ata, ne.atb, workspace)
      ne.reset()
      x.map(x => x.toFloat)
    }

    /**
     * Given a triangular matrix in the packed order used by [[NormalEquation]], compute the full
     * symmetric square matrix that it represents, storing it into `ata`.
     */
    private def fillAtA(triAtA: Array[Double], lambda: Double): Unit = {
      var i = 0
      var pos = 0
      var a = 0.0
      while (i < rank) {
        var j = 0
        while (j <= i) {
          a = triAtA(pos)
          ata(i * rank + j) = a
          ata(j * rank + i) = a
          pos += 1
          j += 1
        }
        ata(i * rank + i) += lambda
        i += 1
      }
    }
  }

  /**
   * Representing a normal equation to solve the following weighted least squares problem:
   *
   * minimize \sum,,i,, c,,i,, (a,,i,,^T^ x - b,,i,,)^2^ + lambda * x^T^ x.
   *
   * Its normal equation is given by
   *
   * \sum,,i,, c,,i,, (a,,i,, a,,i,,^T^ x - b,,i,, a,,i,,) + lambda * x = 0.
   */
  private[recommendation] class NormalEquation(val k: Int) extends Serializable {

    /** Number of entries in the upper triangular part of a k-by-k matrix. */
    val triK = k * (k + 1) / 2
    /** A^T^ * A */
    val ata = new Array[Double](triK)
    /** A^T^ * b */
    val atb = new Array[Double](k)

    private val da = new Array[Double](k)
    private val upper = "U"

    private def copyToDouble(a: Array[Float]): Unit = {
      var i = 0
      while (i < k) {
        da(i) = a(i)
        i += 1
      }
    }

    /** Adds an observation. */
    def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
      require(c >= 0.0)
      require(a.length == k)
      copyToDouble(a)
      blas.dspr(upper, k, c, da, 1, ata)
      if (b != 0.0) {
        blas.daxpy(k, c * b, da, 1, atb, 1)
      }
      this
    }

    /** Merges another normal equation object. */
    def merge(other: NormalEquation): this.type = {
      require(other.k == k)
      blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
      blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
      this
    }

    /** Resets everything to zero, which should be called after each solve. */
    def reset(): Unit = {
      ju.Arrays.fill(ata, 0.0)
      ju.Arrays.fill(atb, 0.0)
    }
  }
  /**
   * :: DeveloperApi ::
   * Implementation of the ALS algorithm.
   */
  @DeveloperApi
  def train[ID: ClassTag]( // scalastyle:ignore
      ratings: RDD[Rating[ID]],
      rank: Int = 10,
      numUserBlocks: Int = 10,
      numItemBlocks: Int = 10,
      maxIter: Int = 10,
      regParam: Double = 1.0,
      implicitPrefs: Boolean = false,
      alpha: Double = 1.0,
      nonnegative: Boolean = false,
      intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
      finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
      checkpointInterval: Int = 10,
      seed: Long = 0L)(
      implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
    require(intermediateRDDStorageLevel != StorageLevel.NONE,
      "ALS is not designed to run without persisting intermediate RDDs.")
    val sc = ratings.sparkContext
    val userPart = new ALSPartitioner(numUserBlocks)
    val itemPart = new ALSPartitioner(numItemBlocks)
    val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
    val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
    val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
    val blockRatings = partitionRatings(ratings, userPart, itemPart)
      .persist(intermediateRDDStorageLevel)
    val (userInBlocks, userOutBlocks) =
      makeBlocks("user", blockRatings, userPart, itemPart, intermediateRDDStorageLevel)
    // materialize blockRatings and user blocks
    userOutBlocks.count()
    val swappedBlockRatings = blockRatings.map {
      case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
        ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
    }
    val (itemInBlocks, itemOutBlocks) =
      makeBlocks("item", swappedBlockRatings, itemPart, userPart, intermediateRDDStorageLevel)
    // materialize item blocks
    itemOutBlocks.count()
    val seedGen = new XORShiftRandom(seed)
    var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
    var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
    var previousCheckpointFile: Option[String] = None
    val shouldCheckpoint: Int => Boolean = (iter) =>
      sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
    val deletePreviousCheckpointFile: () => Unit = () =>
      previousCheckpointFile.foreach { file =>
        try {
          val checkpointFile = new Path(file)
          checkpointFile.getFileSystem(sc.hadoopConfiguration).delete(checkpointFile, true)
        } catch {
          case e: IOException =>
            logWarning(s"Cannot delete checkpoint file $file:", e)
        }
      }
    if (implicitPrefs) {
      for (iter <- 1 to maxIter) {
        userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
        val previousItemFactors = itemFactors
        itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
          userLocalIndexEncoder, implicitPrefs, alpha, solver)
        previousItemFactors.unpersist()
        itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
        // TODO: Generalize PeriodicGraphCheckpointer and use it here.
        val deps = itemFactors.dependencies
        if (shouldCheckpoint(iter)) {
          itemFactors.checkpoint() // itemFactors gets materialized in computeFactors
        }
        val previousUserFactors = userFactors
        userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
          itemLocalIndexEncoder, implicitPrefs, alpha, solver)
        if (shouldCheckpoint(iter)) {
          ALS.cleanShuffleDependencies(sc, deps)
          deletePreviousCheckpointFile()
          previousCheckpointFile = itemFactors.getCheckpointFile
        }
        previousUserFactors.unpersist()
      }
    } else {
      for (iter <- 0 until maxIter) {
        itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
          userLocalIndexEncoder, solver = solver)
        if (shouldCheckpoint(iter)) {
          val deps = itemFactors.dependencies
          itemFactors.checkpoint()
          itemFactors.count() // checkpoint item factors and cut lineage
          ALS.cleanShuffleDependencies(sc, deps)
          deletePreviousCheckpointFile()
          previousCheckpointFile = itemFactors.getCheckpointFile
        }
        userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
          itemLocalIndexEncoder, solver = solver)
      }
    }
    val userIdAndFactors = userInBlocks
      .mapValues(_.srcIds)
      .join(userFactors)
      .mapPartitions({ items =>
        items.flatMap { case (_, (ids, factors)) =>
          ids.view.zip(factors)
        }
      // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks
      // and userFactors.
      }, preservesPartitioning = true)
      .setName("userFactors")
      .persist(finalRDDStorageLevel)
    val itemIdAndFactors = itemInBlocks
      .mapValues(_.srcIds)
      .join(itemFactors)
      .mapPartitions({ items =>
        items.flatMap { case (_, (ids, factors)) =>
          ids.view.zip(factors)
        }
      }, preservesPartitioning = true)
      .setName("itemFactors")
      .persist(finalRDDStorageLevel)
    if (finalRDDStorageLevel != StorageLevel.NONE) {
      userIdAndFactors.count()
      itemFactors.unpersist()
      itemIdAndFactors.count()
      userInBlocks.unpersist()
      userOutBlocks.unpersist()
      itemInBlocks.unpersist()
      itemOutBlocks.unpersist()
      blockRatings.unpersist()
    }
    (userIdAndFactors, itemIdAndFactors)
  }
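
  // A hedged sketch of calling the developer-API `train` above directly (the SparkContext
  // `sc` and the rating values are hypothetical):
  //
  //   import org.apache.spark.ml.recommendation.ALS.Rating
  //   val ratings = sc.parallelize(Seq(
  //     Rating(0, 0, 4.0f), Rating(0, 1, 2.0f), Rating(1, 1, 3.0f)))
  //   val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 5)
  //   // userFactors: RDD[(Int, Array[Float])] keyed by user id; likewise for itemFactors.
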
  /**
   * Factor block that stores factors (Array[Float]) in an Array.
   */
  private type FactorBlock = Array[Array[Float]]

  /**
   * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to
   * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the
   * src factors in this block to send to dst block 0.
   */
  private type OutBlock = Array[Array[Int]]

  /**
   * In-link block for computing src (user/item) factors. This includes the original src IDs
   * of the elements within this block as well as encoded dst (item/user) indices and corresponding
   * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original
   * dst IDs. To compute src factors, we expect to receive dst factors that match the dst indices.
   * For example, if we have an in-link record
   *
   * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0},
   *
   * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which
   * is a blockId to dst factors map, the corresponding dst factor of the record is
   * dstFactors(2)(3).
   *
   * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can
   * compute src factors one after another using only one normal equation instance.
   *
   * @param srcIds src ids (ordered)
   * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and
   *                ratings are associated with srcIds(i).
   * @param dstEncodedIndices encoded dst indices
   * @param ratings ratings
   * @see [[LocalIndexEncoder]]
   */
  private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
      srcIds: Array[ID],
      dstPtrs: Array[Int],
      dstEncodedIndices: Array[Int],
      ratings: Array[Float]) {
    /** Size of the block. */
    def size: Int = ratings.length
    require(dstEncodedIndices.length == size)
    require(dstPtrs.length == srcIds.length + 1)
  }

  /**
   * Initializes factors randomly given the in-link blocks.
   *
   * @param inBlocks in-link blocks
   * @param rank rank
   * @return initialized factor blocks
   */
  private def initialize[ID](
      inBlocks: RDD[(Int, InBlock[ID])],
      rank: Int,
      seed: Long): RDD[(Int, FactorBlock)] = {
    // Choose a unit vector uniformly at random from the unit sphere, but from the
    // "first quadrant" where all elements are nonnegative. This can be done by choosing
    // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
    // This appears to create factorizations that have a slightly better reconstruction
    // (<1%) compared to picking elements uniformly at random in [0,1].
    inBlocks.map { case (srcBlockId, inBlock) =>
      val random = new XORShiftRandom(byteswap64(seed ^ srcBlockId))
      val factors = Array.fill(inBlock.srcIds.length) {
        val factor = Array.fill(rank)(random.nextGaussian().toFloat)
        val nrm = blas.snrm2(rank, factor, 1)
        blas.sscal(rank, 1.0f / nrm, factor, 1)
        factor
      }
      (srcBlockId, factors)
    }
  }

  /**
   * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
   */
  private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
      srcIds: Array[ID],
      dstIds: Array[ID],
      ratings: Array[Float]) {
    /** Size of the block. */
    def size: Int = srcIds.length
    require(dstIds.length == srcIds.length)
    require(ratings.length == srcIds.length)
  }

  /**
   * Builder for [[RatingBlock]]. `mutable.ArrayBuilder` is used to avoid boxing/unboxing.
   */
  private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag]
    extends Serializable {

    private val srcIds = mutable.ArrayBuilder.make[ID]
    private val dstIds = mutable.ArrayBuilder.make[ID]
    private val ratings = mutable.ArrayBuilder.make[Float]
    var size = 0

    /** Adds a rating. */
    def add(r: Rating[ID]): this.type = {
      size += 1
      srcIds += r.user
      dstIds += r.item
      ratings += r.rating
      this
    }

    /** Merges another [[RatingBlock]]. */
    def merge(other: RatingBlock[ID]): this.type = {
      size += other.srcIds.length
      srcIds ++= other.srcIds
      dstIds ++= other.dstIds
      ratings ++= other.ratings
      this
    }

    /** Builds a [[RatingBlock]]. */
    def build(): RatingBlock[ID] = {
      RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
    }
  }
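
  // How `partitionRatings` below flattens a (srcBlockId, dstBlockId) pair into one array
  // index and recovers it again (a small illustrative sketch):
  //
  //   val numSrcParts = 10
  //   val srcBlockId = 3; val dstBlockId = 7
  //   val idx = srcBlockId + numSrcParts * dstBlockId  // 73
  //   assert(idx % numSrcParts == srcBlockId)          // recover the src block
  //   assert(idx / numSrcParts == dstBlockId)          // recover the dst block
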
  /**
   * Partitions raw ratings into blocks.
   *
   * @param ratings raw ratings
   * @param srcPart partitioner for src IDs
   * @param dstPart partitioner for dst IDs
   * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
   */
  private def partitionRatings[ID: ClassTag](
      ratings: RDD[Rating[ID]],
      srcPart: Partitioner,
      dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {

    /* The implementation produces the same result as the following but generates fewer objects.

    ratings.map { r =>
      ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r)
    }.aggregateByKey(new RatingBlockBuilder)(
      seqOp = (b, r) => b.add(r),
      combOp = (b0, b1) => b0.merge(b1.build()))
      .mapValues(_.build())
    */

    val numPartitions = srcPart.numPartitions * dstPart.numPartitions
    ratings.mapPartitions { iter =>
      val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID])
      iter.flatMap { r =>
        val srcBlockId = srcPart.getPartition(r.user)
        val dstBlockId = dstPart.getPartition(r.item)
        val idx = srcBlockId + srcPart.numPartitions * dstBlockId
        val builder = builders(idx)
        builder.add(r)
        if (builder.size >= 2048) { // 2048 * (3 * 4) bytes = 24 KB
          builders(idx) = new RatingBlockBuilder
          Iterator.single(((srcBlockId, dstBlockId), builder.build()))
        } else {
          Iterator.empty
        }
      } ++ {
        builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
          val srcBlockId = idx % srcPart.numPartitions
          val dstBlockId = idx / srcPart.numPartitions
          ((srcBlockId, dstBlockId), block.build())
        }
      }
    }.groupByKey().mapValues { blocks =>
      val builder = new RatingBlockBuilder[ID]
      blocks.foreach(builder.merge)
      builder.build()
    }.setName("ratingBlocks")
  }

  /**
   * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
   *
   * @param encoder encoder for dst indices
   */
  private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
      encoder: LocalIndexEncoder)(
      implicit ord: Ordering[ID]) {

    private val srcIds = mutable.ArrayBuilder.make[ID]
    private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
    private val ratings = mutable.ArrayBuilder.make[Float]

    /**
     * Adds a dst block of (srcId, dstLocalIndex, rating) tuples.
     *
     * @param dstBlockId dst block ID
     * @param srcIds original src IDs
     * @param dstLocalIndices dst local indices
     * @param ratings ratings
     */
    def add(
        dstBlockId: Int,
        srcIds: Array[ID],
        dstLocalIndices: Array[Int],
        ratings: Array[Float]): this.type = {
      val sz = srcIds.length
      require(dstLocalIndices.length == sz)
      require(ratings.length == sz)
      this.srcIds ++= srcIds
      this.ratings ++= ratings
      var j = 0
      while (j < sz) {
        this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
        j += 1
      }
      this
    }

    /** Builds an [[UncompressedInBlock]]. */
    def build(): UncompressedInBlock[ID] = {
      new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
    }
  }
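
  // A worked example of the COO-to-CSC compression performed by `compress()` below
  // (illustrative values):
  //
  //   after sort():  srcIds            = [0, 0, 1, 1, 2]
  //                  dstEncodedIndices = [e0, e1, e2, e3, e4]  (moved along with srcIds)
  //   compressed:    uniqueSrcIds      = [0, 1, 2]
  //                  dstPtrs           = [0, 2, 4, 5]
  //   so the entries for srcIds(i) live in positions [dstPtrs(i), dstPtrs(i + 1)).
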
  /**
   * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
   */
  private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag](
      val srcIds: Array[ID],
      val dstEncodedIndices: Array[Int],
      val ratings: Array[Float])(
      implicit ord: Ordering[ID]) {

    /** Size of the block. */
    def length: Int = srcIds.length

    /**
     * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
     * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
     * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
     */
    def compress(): InBlock[ID] = {
      val sz = length
      assert(sz > 0, "Empty in-link block should not exist.")
      sort()
      val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID]
      val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
      var preSrcId = srcIds(0)
      uniqueSrcIdsBuilder += preSrcId
      var curCount = 1
      var i = 1
      while (i < sz) {
        val srcId = srcIds(i)
        if (srcId != preSrcId) {
          uniqueSrcIdsBuilder += srcId
          dstCountsBuilder += curCount
          preSrcId = srcId
          curCount = 0
        }
        curCount += 1
        i += 1
      }
      dstCountsBuilder += curCount
      val uniqueSrcIds = uniqueSrcIdsBuilder.result()
      val numUniqueSrcIds = uniqueSrcIds.length
      val dstCounts = dstCountsBuilder.result()
      val dstPtrs = new Array[Int](numUniqueSrcIds + 1)
      var sum = 0
      i = 0
      while (i < numUniqueSrcIds) {
        sum += dstCounts(i)
        i += 1
        dstPtrs(i) = sum
      }
      InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
    }

    private def sort(): Unit = {
      val sz = length
      // Since there might be interleaved log messages, we insert a unique id for easy pairing.
      val sortId = Utils.random.nextInt()
      logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
      val start = System.nanoTime()
      val sorter = new Sorter(new UncompressedInBlockSort[ID])
      sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]])
      val duration = (System.nanoTime() - start) / 1e9
      logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
    }
  }

  /**
   * A wrapper that holds a primitive key.
   *
   * @see [[UncompressedInBlockSort]]
   */
  private class KeyWrapper[@specialized(Int, Long) ID: ClassTag](
      implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {

    var key: ID = _

    override def compare(that: KeyWrapper[ID]): Int = {
      ord.compare(key, that.key)
    }

    def setKey(key: ID): this.type = {
      this.key = key
      this
    }
  }

  /**
   * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
   */
  private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag](
      implicit ord: Ordering[ID])
    extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] {

    override def newKey(): KeyWrapper[ID] = new KeyWrapper()

    override def getKey(
        data: UncompressedInBlock[ID],
        pos: Int,
        reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
      if (reuse == null) {
        new KeyWrapper().setKey(data.srcIds(pos))
      } else {
        reuse.setKey(data.srcIds(pos))
      }
    }

    override def getKey(
        data: UncompressedInBlock[ID],
        pos: Int): KeyWrapper[ID] = {
      getKey(data, pos, null)
    }

    private def swapElements[@specialized(Int, Float) T](
        data: Array[T],
        pos0: Int,
        pos1: Int): Unit = {
      val tmp = data(pos0)
      data(pos0) = data(pos1)
      data(pos1) = tmp
    }

    override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = {
      swapElements(data.srcIds, pos0, pos1)
      swapElements(data.dstEncodedIndices, pos0, pos1)
      swapElements(data.ratings, pos0, pos1)
    }

    override def copyRange(
        src: UncompressedInBlock[ID],
        srcPos: Int,
        dst: UncompressedInBlock[ID],
        dstPos: Int,
        length: Int): Unit = {
      System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
      System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
      System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
    }

    override def allocate(length: Int): UncompressedInBlock[ID] = {
      new UncompressedInBlock(
        new Array[ID](length), new Array[Int](length), new Array[Float](length))
    }

    override def copyElement(
        src: UncompressedInBlock[ID],
        srcPos: Int,
        dst: UncompressedInBlock[ID],
        dstPos: Int): Unit = {
      dst.srcIds(dstPos) = src.srcIds(srcPos)
      dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
      dst.ratings(dstPos) = src.ratings(srcPos)
    }
  }

  /**
   * Creates in-blocks and out-blocks from rating blocks.
   *
   * @param prefix prefix for in/out-block names
   * @param ratingBlocks rating blocks
   * @param srcPart partitioner for src IDs
   * @param dstPart partitioner for dst IDs
   * @return (in-blocks, out-blocks)
   */
  private def makeBlocks[ID: ClassTag](
      prefix: String,
      ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
      srcPart: Partitioner,
      dstPart: Partitioner,
      storageLevel: StorageLevel)(
      implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
    val inBlocks = ratingBlocks.map {
      case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
        // The implementation is a faster version of
        // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
        val start = System.nanoTime()
        val dstIdSet = new OpenHashSet[ID](1 << 20)
        dstIds.foreach(dstIdSet.add)
        val sortedDstIds = new Array[ID](dstIdSet.size)
        var i = 0
        var pos = dstIdSet.nextPos(0)
        while (pos != -1) {
          sortedDstIds(i) = dstIdSet.getValue(pos)
          pos = dstIdSet.nextPos(pos + 1)
          i += 1
        }
        assert(i == dstIdSet.size)
        Sorting.quickSort(sortedDstIds)
        val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length)
        i = 0
        while (i < sortedDstIds.length) {
          dstIdToLocalIndex.update(sortedDstIds(i), i)
          i += 1
        }
        logDebug(
          "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
        val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
        (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
    }.groupByKey(new ALSPartitioner(srcPart.numPartitions))
      .mapValues { iter =>
        val builder =
          new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
        iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
          builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
        }
        builder.build().compress()
      }.setName(prefix + "InBlocks")
      .persist(storageLevel)
    val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
      val encoder = new LocalIndexEncoder(dstPart.numPartitions)
      val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
      var i = 0
      val seen = new Array[Boolean](dstPart.numPartitions)
      while (i < srcIds.length) {
        var j = dstPtrs(i)
        ju.Arrays.fill(seen, false)
        while (j < dstPtrs(i + 1)) {
          val dstBlockId = encoder.blockId(dstEncodedIndices(j))
          if (!seen(dstBlockId)) {
            activeIds(dstBlockId) += i // add the local index in this out-block
            seen(dstBlockId) = true
          }
          j += 1
        }
        i += 1
      }
      activeIds.map { x =>
        x.result()
      }
    }.setName(prefix + "OutBlocks")
      .persist(storageLevel)
    (inBlocks, outBlocks)
  }
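
  // A worked example of the (blockId, localIndex) packing performed by LocalIndexEncoder
  // (defined near the end of this file); values are illustrative:
  //
  //   val encoder = new LocalIndexEncoder(numBlocks = 10)
  //   // 10 blocks need 4 bits for the block id, leaving 28 bits for the local index.
  //   val encoded = encoder.encode(blockId = 3, localIndex = 5)
  //   assert(encoder.blockId(encoded) == 3)
  //   assert(encoder.localIndex(encoded) == 5)
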
  /**
   * Compute dst factors by constructing and solving least square problems.
   *
   * @param srcFactorBlocks src factors
   * @param srcOutBlocks src out-blocks
   * @param dstInBlocks dst in-blocks
   * @param rank rank
   * @param regParam regularization constant
   * @param srcEncoder encoder for src local indices
   * @param implicitPrefs whether to use implicit preference
   * @param alpha the alpha constant in the implicit preference formulation
   * @param solver solver for least squares problems
   * @return dst factors
   */
  private def computeFactors[ID](
      srcFactorBlocks: RDD[(Int, FactorBlock)],
      srcOutBlocks: RDD[(Int, OutBlock)],
      dstInBlocks: RDD[(Int, InBlock[ID])],
      rank: Int,
      regParam: Double,
      srcEncoder: LocalIndexEncoder,
      implicitPrefs: Boolean = false,
      alpha: Double = 1.0,
      solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
    val numSrcBlocks = srcFactorBlocks.partitions.length
    val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
    val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
      case (srcBlockId, (srcOutBlock, srcFactors)) =>
        srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
          (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
        }
    }
    val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length))
    dstInBlocks.join(merged).mapValues {
      case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
        val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
        srcFactors.foreach { case (srcBlockId, factors) =>
          sortedSrcFactors(srcBlockId) = factors
        }
        val dstFactors = new Array[Array[Float]](dstIds.length)
        var j = 0
        val ls = new NormalEquation(rank)
        while (j < dstIds.length) {
          ls.reset()
          if (implicitPrefs) {
            ls.merge(YtY.get)
          }
          var i = srcPtrs(j)
          var numExplicits = 0
          while (i < srcPtrs(j + 1)) {
            val encoded = srcEncodedIndices(i)
            val blockId = srcEncoder.blockId(encoded)
            val localIndex = srcEncoder.localIndex(encoded)
            val srcFactor = sortedSrcFactors(blockId)(localIndex)
            val rating = ratings(i)
            if (implicitPrefs) {
              // Extension to the original paper to handle b < 0. confidence is a function of |b|
              // instead so that it is never negative. c1 is confidence - 1.0.
              val c1 = alpha * math.abs(rating)
              // For rating <= 0, the corresponding preference is 0. So the term below is only
              // added for rating > 0. Because YtY is already added, we need to adjust the scaling
              // here.
              if (rating > 0) {
                numExplicits += 1
                ls.add(srcFactor, (c1 + 1.0) / c1, c1)
              }
            } else {
              ls.add(srcFactor, rating)
              numExplicits += 1
            }
            i += 1
          }
          // Weight lambda by the number of explicit ratings based on the ALS-WR paper.
          dstFactors(j) = solver.solve(ls, numExplicits * regParam)
          j += 1
        }
        dstFactors
    }
  }
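
  // Why `ls.add(srcFactor, (c1 + 1.0) / c1, c1)` above encodes the implicit-feedback model
  // (a short derivation, not new behavior): NormalEquation.add(a, b, c) contributes
  // c * a * a^T to A^T A and c * b * a to A^T b. With c = c1 and b = (c1 + 1) / c1, a rated
  // item therefore adds c1 * a * a^T and (c1 + 1) * a. Since the unweighted Gramian Y^T Y
  // (i.e. a * a^T with weight 1 for every item) was already merged in, the total weight on
  // a * a^T becomes 1 + c1 = confidence, and A^T b receives confidence * preference * a with
  // preference = 1, matching Hu et al.'s confidence-weighted objective.
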
  /**
   * Computes the Gramian matrix of user or item factors, which is only used in implicit preference.
   * Caching of the input factors is handled in [[ALS#train]].
   */
  private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
    factorBlocks.values.aggregate(new NormalEquation(rank))(
      seqOp = (ne, factors) => {
        factors.foreach(ne.add(_, 0.0))
        ne
      },
      combOp = (ne1, ne2) => ne1.merge(ne2))
  }

  /**
   * Encoder for storing (blockId, localIndex) into a single integer.
   *
   * We use the leading bits (including the sign bit) to store the block id and the rest to store
   * the local index. This is based on the assumption that users/items are approximately evenly
   * partitioned. With this assumption, we should be able to encode two billion distinct values.
   *
   * @param numBlocks number of blocks
   */
  private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable {

    require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.")

    private[this] final val numLocalIndexBits =
      math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)
    private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1

    /** Encodes a (blockId, localIndex) into a single integer. */
    def encode(blockId: Int, localIndex: Int): Int = {
      require(blockId < numBlocks)
      require((localIndex & ~localIndexMask) == 0)
      (blockId << numLocalIndexBits) | localIndex
    }

    /** Gets the block id from an encoded index. */
    @inline
    def blockId(encoded: Int): Int = {
      encoded >>> numLocalIndexBits
    }

    /** Gets the local index from an encoded index. */
    @inline
    def localIndex(encoded: Int): Int = {
      encoded & localIndexMask
    }
  }

  /**
   * Partitioner used by ALS. We require that getPartition is a projection. That is, for any key k,
   * we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner
   * satisfies this requirement, we simply use a type alias here.
   */
  private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner

  /**
   * Private function to clean up all of the shuffle files from the dependencies and their parents.
   */
  private[spark] def cleanShuffleDependencies[T](
      sc: SparkContext,
      deps: Seq[Dependency[_]],
      blocking: Boolean = false): Unit = {
    // If there is no reference tracking we skip clean up.
    sc.cleaner.foreach { cleaner =>
      /**
       * Clean the shuffle and all of its parents.
       */
      def cleanEagerly(dep: Dependency[_]): Unit = {
        if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) {
          val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
          cleaner.doCleanupShuffle(shuffleId, blocking)
        }
        val rdd = dep.rdd
        val rddDeps = rdd.dependencies
        if (rdd.getStorageLevel == StorageLevel.NONE && rddDeps != null) {
          rddDeps.foreach(cleanEagerly)
        }
      }
      deps.foreach(cleanEagerly)
    }
  }
}