1 /** 2 * @file methods/cf/decomposition_policies/svd_complete_method.hpp 3 * @author Haritha Nair 4 * 5 * Implementation of the SVD complete incremental method for use in the 6 * Collaborative Filtering. 7 * 8 * mlpack is free software; you may redistribute it and/or modify it under the 9 * terms of the 3-clause BSD license. You should have received a copy of the 10 * 3-clause BSD license along with mlpack. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 14 #ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_SVD_COMPLETE_METHOD_HPP 15 #define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_SVD_COMPLETE_METHOD_HPP 16 17 #include <mlpack/prereqs.hpp> 18 #include <mlpack/methods/amf/amf.hpp> 19 #include <mlpack/methods/amf/update_rules/nmf_als.hpp> 20 #include <mlpack/methods/amf/termination_policies/max_iteration_termination.hpp> 21 #include <mlpack/methods/amf/termination_policies/simple_residue_termination.hpp> 22 23 namespace mlpack { 24 namespace cf { 25 26 /** 27 * Implementation of the SVD complete incremental policy to act as a wrapper 28 * when accessing SVD complete decomposition from within CFType. 29 * 30 * An example of how to use SVDCompletePolicy in CF is shown below: 31 * 32 * @code 33 * extern arma::mat data; // data is a (user, item, rating) table. 34 * // Users for whom recommendations are generated. 35 * extern arma::Col<size_t> users; 36 * arma::Mat<size_t> recommendations; // Resulting recommendations. 37 * 38 * CFType<SVDCompletePolicy> cf(data); 39 * 40 * // Generate 10 recommendations for all users. 41 * cf.GetRecommendations(10, recommendations); 42 * @endcode 43 */ 44 class SVDCompletePolicy 45 { 46 public: 47 /** 48 * Apply Collaborative Filtering to the provided data set using the 49 * SVD complete incremental policy. 50 * 51 * @param * (data) Data matrix: dense matrix (coordinate lists) 52 * or sparse matrix(cleaned). 53 * @param cleanedData item user table in form of sparse matrix. 54 * @param rank Rank parameter for matrix factorization. 55 * @param maxIterations Maximum number of iterations. 56 * @param minResidue Residue required to terminate. 57 * @param mit Whether to terminate only when maxIterations is reached. 58 */ 59 template<typename MatType> Apply(const MatType &,const arma::sp_mat & cleanedData,const size_t rank,const size_t maxIterations,const double minResidue,const bool mit)60 void Apply(const MatType& /* data */, 61 const arma::sp_mat& cleanedData, 62 const size_t rank, 63 const size_t maxIterations, 64 const double minResidue, 65 const bool mit) 66 { 67 if (mit) 68 { 69 amf::MaxIterationTermination iter(maxIterations); 70 71 // Do singular value decomposition using complete incremental method 72 // using cleaned data in form of sparse matrix. 73 amf::AMF<amf::MaxIterationTermination, amf::RandomInitialization, 74 amf::SVDCompleteIncrementalLearning<arma::sp_mat>> svdci(iter); 75 76 svdci.Apply(cleanedData, rank, w, h); 77 } 78 else 79 { 80 amf::SimpleResidueTermination srt(minResidue, maxIterations); 81 82 // Do singular value decomposition using complete incremental method 83 // using cleaned data in form of sparse matrix. 84 amf::SVDCompleteIncrementalFactorizer<arma::sp_mat> svdci(srt); 85 86 svdci.Apply(cleanedData, rank, w, h); 87 } 88 } 89 90 /** 91 * Return predicted rating given user ID and item ID. 92 * 93 * @param user User ID. 94 * @param item Item ID. 95 */ GetRating(const size_t user,const size_t item) const96 double GetRating(const size_t user, const size_t item) const 97 { 98 double rating = arma::as_scalar(w.row(item) * h.col(user)); 99 return rating; 100 } 101 102 /** 103 * Get predicted ratings for a user. 104 * 105 * @param user User ID. 106 * @param rating Resulting rating vector. 107 */ GetRatingOfUser(const size_t user,arma::vec & rating) const108 void GetRatingOfUser(const size_t user, arma::vec& rating) const 109 { 110 rating = w * h.col(user); 111 } 112 113 /** 114 * Get the neighborhood and corresponding similarities for a set of users. 115 * 116 * @tparam NeighborSearchPolicy The policy to perform neighbor search. 117 * 118 * @param users Users whose neighborhood is to be computed. 119 * @param numUsersForSimilarity The number of neighbors returned for 120 * each user. 121 * @param neighborhood Neighbors represented by user IDs. 122 * @param similarities Similarity between each user and each of its 123 * neighbors. 124 */ 125 template<typename NeighborSearchPolicy> GetNeighborhood(const arma::Col<size_t> & users,const size_t numUsersForSimilarity,arma::Mat<size_t> & neighborhood,arma::mat & similarities) const126 void GetNeighborhood(const arma::Col<size_t>& users, 127 const size_t numUsersForSimilarity, 128 arma::Mat<size_t>& neighborhood, 129 arma::mat& similarities) const 130 { 131 // We want to avoid calculating the full rating matrix, so we will do 132 // nearest neighbor search only on the H matrix, using the observation that 133 // if the rating matrix X = W*H, then d(X.col(i), X.col(j)) = d(W H.col(i), 134 // W H.col(j)). This can be seen as nearest neighbor search on the H 135 // matrix with the Mahalanobis distance where M^{-1} = W^T W. So, we'll 136 // decompose M^{-1} = L L^T (the Cholesky decomposition), and then multiply 137 // H by L^T. Then we can perform nearest neighbor search. 138 arma::mat l = arma::chol(w.t() * w); 139 arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T. 140 141 // Temporarily store feature vector of queried users. 142 arma::mat query(stretchedH.n_rows, users.n_elem); 143 // Select feature vectors of queried users. 144 for (size_t i = 0; i < users.n_elem; ++i) 145 query.col(i) = stretchedH.col(users(i)); 146 147 NeighborSearchPolicy neighborSearch(stretchedH); 148 neighborSearch.Search( 149 query, numUsersForSimilarity, neighborhood, similarities); 150 } 151 152 //! Get the Item Matrix. W() const153 const arma::mat& W() const { return w; } 154 //! Get the User Matrix. H() const155 const arma::mat& H() const { return h; } 156 157 /** 158 * Serialization. 159 */ 160 template<typename Archive> serialize(Archive & ar,const unsigned int)161 void serialize(Archive& ar, const unsigned int /* version */) 162 { 163 ar & BOOST_SERIALIZATION_NVP(w); 164 ar & BOOST_SERIALIZATION_NVP(h); 165 } 166 167 private: 168 //! Item matrix. 169 arma::mat w; 170 //! User matrix. 171 arma::mat h; 172 }; 173 174 } // namespace cf 175 } // namespace mlpack 176 177 #endif 178