1 /**
2  * @file methods/cf/decomposition_policies/batch_svd_method.hpp
3  * @author Haritha Nair
4  *
5  * Implementation of the batch SVD method for use in Collaborative Filtering.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 
13 #ifndef MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BATCH_SVD_METHOD_HPP
14 #define MLPACK_METHODS_CF_DECOMPOSITION_POLICIES_BATCH_SVD_METHOD_HPP
15 
#include <stdexcept>

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/amf/amf.hpp>
#include <mlpack/methods/amf/update_rules/nmf_als.hpp>
#include <mlpack/methods/amf/termination_policies/simple_residue_termination.hpp>
#include <mlpack/methods/amf/termination_policies/max_iteration_termination.hpp>
21 
22 namespace mlpack {
23 namespace cf {
24 
25 /**
26  * Implementation of the Batch SVD policy to act as a wrapper when accessing
27  * Batch SVD from within CFType.
28  *
29  * An example of how to use BatchSVDPolicy in CF is shown below:
30  *
31  * @code
32  * extern arma::mat data; // data is a (user, item, rating) table.
33  * // Users for whom recommendations are generated.
34  * extern arma::Col<size_t> users;
35  * arma::Mat<size_t> recommendations; // Resulting recommendations.
36  *
37  * CFType<BatchSVDPolicy> cf(data);
38  *
39  * // Generate 10 recommendations for all users.
40  * cf.GetRecommendations(10, recommendations);
41  * @endcode
42  */
43 class BatchSVDPolicy
44 {
45  public:
46   /**
47    * Apply Collaborative Filtering to the provided data set using the
48    * batch SVD method.
49    *
50    * @param * (data) Data matrix: dense matrix (coordinate lists)
51    *    or sparse matrix(cleaned).
52    * @param cleanedData item user table in form of sparse matrix.
53    * @param rank Rank parameter for matrix factorization.
54    * @param maxIterations Maximum number of iterations.
55    * @param minResidue Residue required to terminate.
56    * @param mit Whether to terminate only when maxIterations is reached.
57    */
58   template<typename MatType>
Apply(const MatType &,const arma::sp_mat & cleanedData,const size_t rank,const size_t maxIterations,const double minResidue,const bool mit)59   void Apply(const MatType& /* data */,
60              const arma::sp_mat& cleanedData,
61              const size_t rank,
62              const size_t maxIterations,
63              const double minResidue,
64              const bool mit)
65   {
66     if (mit)
67     {
68       amf::MaxIterationTermination iter(maxIterations);
69 
70       // Do singular value decomposition using the batch SVD algorithm.
71       amf::AMF<amf::MaxIterationTermination, amf::RandomInitialization,
72           amf::SVDBatchLearning> svdbatch(iter);
73 
74       svdbatch.Apply(cleanedData, rank, w, h);
75     }
76     else
77     {
78       amf::SimpleResidueTermination srt(minResidue, maxIterations);
79 
80       // Do singular value decomposition using the batch SVD algorithm.
81       amf::SVDBatchFactorizer<> svdbatch(srt);
82 
83       svdbatch.Apply(cleanedData, rank, w, h);
84     }
85   }
86 
87   /**
88    * Return predicted rating given user ID and item ID.
89    *
90    * @param user User ID.
91    * @param item Item ID.
92    */
GetRating(const size_t user,const size_t item) const93   double GetRating(const size_t user, const size_t item) const
94   {
95     double rating = arma::as_scalar(w.row(item) * h.col(user));
96     return rating;
97   }
98 
99   /**
100    * Get predicted ratings for a user.
101    *
102    * @param user User ID.
103    * @param rating Resulting rating vector.
104    */
GetRatingOfUser(const size_t user,arma::vec & rating) const105   void GetRatingOfUser(const size_t user, arma::vec& rating) const
106   {
107     rating = w * h.col(user);
108   }
109 
110   /**
111    * Get the neighborhood and corresponding similarities for a set of users.
112    *
113    * @tparam NeighborSearchPolicy The policy to perform neighbor search.
114    *
115    * @param users Users whose neighborhood is to be computed.
116    * @param numUsersForSimilarity The number of neighbors returned for
117    *     each user.
118    * @param neighborhood Neighbors represented by user IDs.
119    * @param similarities Similarity between each user and each of its
120    *     neighbors.
121    */
122   template<typename NeighborSearchPolicy>
GetNeighborhood(const arma::Col<size_t> & users,const size_t numUsersForSimilarity,arma::Mat<size_t> & neighborhood,arma::mat & similarities) const123   void GetNeighborhood(const arma::Col<size_t>& users,
124                        const size_t numUsersForSimilarity,
125                        arma::Mat<size_t>& neighborhood,
126                        arma::mat& similarities) const
127   {
128     // We want to avoid calculating the full rating matrix, so we will do
129     // nearest neighbor search only on the H matrix, using the observation that
130     // if the rating matrix X = W*H, then d(X.col(i), X.col(j)) = d(W H.col(i),
131     // W H.col(j)).  This can be seen as nearest neighbor search on the H
132     // matrix with the Mahalanobis distance where M^{-1} = W^T W.  So, we'll
133     // decompose M^{-1} = L L^T (the Cholesky decomposition), and then multiply
134     // H by L^T. Then we can perform nearest neighbor search.
135     arma::mat l = arma::chol(w.t() * w);
136     arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T.
137 
138     // Temporarily store feature vector of queried users.
139     arma::mat query(stretchedH.n_rows, users.n_elem);
140     // Select feature vectors of queried users.
141     for (size_t i = 0; i < users.n_elem; ++i)
142       query.col(i) = stretchedH.col(users(i));
143 
144     NeighborSearchPolicy neighborSearch(stretchedH);
145     neighborSearch.Search(
146         query, numUsersForSimilarity, neighborhood, similarities);
147   }
148 
149   //! Get the Item Matrix.
W() const150   const arma::mat& W() const { return w; }
151   //! Get the User Matrix.
H() const152   const arma::mat& H() const { return h; }
153 
154   /**
155    * Serialization.
156    */
157   template<typename Archive>
serialize(Archive & ar,const unsigned int)158   void serialize(Archive& ar, const unsigned int /* version */)
159   {
160     ar & BOOST_SERIALIZATION_NVP(w);
161     ar & BOOST_SERIALIZATION_NVP(h);
162   }
163 
164  private:
165   //! Item matrix.
166   arma::mat w;
167   //! User matrix.
168   arma::mat h;
169 };
170 
171 } // namespace cf
172 } // namespace mlpack
173 
174 #endif
175