1 /** 2 * @file methods/cf/decomposition_policies/batch_svd_method.hpp 3 * @author Haritha Nair 4 * 5 * Implementation of the batch SVD method for use in Collaborative Filtering. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 13 #ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BATCH_SVD_METHOD_HPP 14 #define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BATCH_SVD_METHOD_HPP 15 16 #include <mlpack/prereqs.hpp> 17 #include <mlpack/methods/amf/amf.hpp> 18 #include <mlpack/methods/amf/update_rules/nmf_als.hpp> 19 #include <mlpack/methods/amf/termination_policies/simple_residue_termination.hpp> 20 #include <mlpack/methods/amf/termination_policies/max_iteration_termination.hpp> 21 22 namespace mlpack { 23 namespace cf { 24 25 /** 26 * Implementation of the Batch SVD policy to act as a wrapper when accessing 27 * Batch SVD from within CFType. 28 * 29 * An example of how to use BatchSVDPolicy in CF is shown below: 30 * 31 * @code 32 * extern arma::mat data; // data is a (user, item, rating) table. 33 * // Users for whom recommendations are generated. 34 * extern arma::Col<size_t> users; 35 * arma::Mat<size_t> recommendations; // Resulting recommendations. 36 * 37 * CFType<BatchSVDPolicy> cf(data); 38 * 39 * // Generate 10 recommendations for all users. 40 * cf.GetRecommendations(10, recommendations); 41 * @endcode 42 */ 43 class BatchSVDPolicy 44 { 45 public: 46 /** 47 * Apply Collaborative Filtering to the provided data set using the 48 * batch SVD method. 49 * 50 * @param * (data) Data matrix: dense matrix (coordinate lists) 51 * or sparse matrix(cleaned). 52 * @param cleanedData item user table in form of sparse matrix. 53 * @param rank Rank parameter for matrix factorization. 54 * @param maxIterations Maximum number of iterations. 55 * @param minResidue Residue required to terminate. 56 * @param mit Whether to terminate only when maxIterations is reached. 57 */ 58 template<typename MatType> Apply(const MatType &,const arma::sp_mat & cleanedData,const size_t rank,const size_t maxIterations,const double minResidue,const bool mit)59 void Apply(const MatType& /* data */, 60 const arma::sp_mat& cleanedData, 61 const size_t rank, 62 const size_t maxIterations, 63 const double minResidue, 64 const bool mit) 65 { 66 if (mit) 67 { 68 amf::MaxIterationTermination iter(maxIterations); 69 70 // Do singular value decomposition using the batch SVD algorithm. 71 amf::AMF<amf::MaxIterationTermination, amf::RandomInitialization, 72 amf::SVDBatchLearning> svdbatch(iter); 73 74 svdbatch.Apply(cleanedData, rank, w, h); 75 } 76 else 77 { 78 amf::SimpleResidueTermination srt(minResidue, maxIterations); 79 80 // Do singular value decomposition using the batch SVD algorithm. 81 amf::SVDBatchFactorizer<> svdbatch(srt); 82 83 svdbatch.Apply(cleanedData, rank, w, h); 84 } 85 } 86 87 /** 88 * Return predicted rating given user ID and item ID. 89 * 90 * @param user User ID. 91 * @param item Item ID. 92 */ GetRating(const size_t user,const size_t item) const93 double GetRating(const size_t user, const size_t item) const 94 { 95 double rating = arma::as_scalar(w.row(item) * h.col(user)); 96 return rating; 97 } 98 99 /** 100 * Get predicted ratings for a user. 101 * 102 * @param user User ID. 103 * @param rating Resulting rating vector. 104 */ GetRatingOfUser(const size_t user,arma::vec & rating) const105 void GetRatingOfUser(const size_t user, arma::vec& rating) const 106 { 107 rating = w * h.col(user); 108 } 109 110 /** 111 * Get the neighborhood and corresponding similarities for a set of users. 112 * 113 * @tparam NeighborSearchPolicy The policy to perform neighbor search. 114 * 115 * @param users Users whose neighborhood is to be computed. 116 * @param numUsersForSimilarity The number of neighbors returned for 117 * each user. 118 * @param neighborhood Neighbors represented by user IDs. 119 * @param similarities Similarity between each user and each of its 120 * neighbors. 121 */ 122 template<typename NeighborSearchPolicy> GetNeighborhood(const arma::Col<size_t> & users,const size_t numUsersForSimilarity,arma::Mat<size_t> & neighborhood,arma::mat & similarities) const123 void GetNeighborhood(const arma::Col<size_t>& users, 124 const size_t numUsersForSimilarity, 125 arma::Mat<size_t>& neighborhood, 126 arma::mat& similarities) const 127 { 128 // We want to avoid calculating the full rating matrix, so we will do 129 // nearest neighbor search only on the H matrix, using the observation that 130 // if the rating matrix X = W*H, then d(X.col(i), X.col(j)) = d(W H.col(i), 131 // W H.col(j)). This can be seen as nearest neighbor search on the H 132 // matrix with the Mahalanobis distance where M^{-1} = W^T W. So, we'll 133 // decompose M^{-1} = L L^T (the Cholesky decomposition), and then multiply 134 // H by L^T. Then we can perform nearest neighbor search. 135 arma::mat l = arma::chol(w.t() * w); 136 arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T. 137 138 // Temporarily store feature vector of queried users. 139 arma::mat query(stretchedH.n_rows, users.n_elem); 140 // Select feature vectors of queried users. 141 for (size_t i = 0; i < users.n_elem; ++i) 142 query.col(i) = stretchedH.col(users(i)); 143 144 NeighborSearchPolicy neighborSearch(stretchedH); 145 neighborSearch.Search( 146 query, numUsersForSimilarity, neighborhood, similarities); 147 } 148 149 //! Get the Item Matrix. W() const150 const arma::mat& W() const { return w; } 151 //! Get the User Matrix. H() const152 const arma::mat& H() const { return h; } 153 154 /** 155 * Serialization. 156 */ 157 template<typename Archive> serialize(Archive & ar,const unsigned int)158 void serialize(Archive& ar, const unsigned int /* version */) 159 { 160 ar & BOOST_SERIALIZATION_NVP(w); 161 ar & BOOST_SERIALIZATION_NVP(h); 162 } 163 164 private: 165 //! Item matrix. 166 arma::mat w; 167 //! User matrix. 168 arma::mat h; 169 }; 170 171 } // namespace cf 172 } // namespace mlpack 173 174 #endif 175