1 // This may look like C code, but it's really -*- C++ -*-
2 /*
3  * Copyright (C) 2011 Emweb bv, Herent, Belgium.
4  *
5  * See the LICENSE file for terms of use.
6  */
7 #ifndef WT_AUTH_DBO_USER_DATABASE_H_
8 #define WT_AUTH_DBO_USER_DATABASE_H_
9 
#include <Wt/Auth/AbstractUserDatabase.h>
#include <Wt/Auth/AuthService.h>
#include <Wt/Auth/Dbo/AuthInfo.h>
#include <Wt/Dbo/Exception.h>
#include <Wt/WException.h>
#include <Wt/WLogger.h>

#include <algorithm>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
15 
16 namespace Wt {
17   namespace Auth {
18     namespace Dbo {
19 
20 /*! \class UserDatabase Wt/Auth/Dbo/UserDatabase.h
21  *  \brief A default implementation for an authentication user database.
22  *
23  * This is a template class, and needs as parameter the Dbo type which
24  * models the authentication information. A suitable implementation,
25  * which stores authentication information outside the "user" class,
26  * is provided by AuthInfo.
27  *
28  * \sa AuthInfo
29  *
30  * \ingroup auth
31  */
32 template <class DboType>
33 class UserDatabase : public AbstractUserDatabase
34 {
35   typedef typename DboType::AuthTokenType AuthTokenType;
36   typedef Wt::Dbo::collection< Wt::Dbo::ptr<AuthTokenType> > AuthTokens;
37 
38   typedef typename DboType::AuthIdentityType AuthIdentityType;
39   typedef Wt::Dbo::collection< Wt::Dbo::ptr<AuthIdentityType> > AuthIdentities;
40 
41 public:
42   /*! \brief Constructor
43    *
44    * The AuthService parameter is optional, and decides some of the
45    * UserDatabase's behaviour. Currently, this decides whether findWithIdentity()
46    * should be case sensitive or not. If the identity policy of the AuthService is
47    * EmailAddressIdentity, then findWithIdentity() will be case insensitive. Otherwise
48    * it is case sensitive.
49    */
50   UserDatabase(Wt::Dbo::Session& session, const AuthService *authService = 0)
session_(session)51     : session_(session),
52       newUserStatus_(AccountStatus::Normal),
53       authService_(authService),
54       maxAuthTokensPerUser_(50)
55   { }
56 
57   /*! \brief Sets the initial status for a new user.
58    *
59    * This status is set on a user that just registered.
60    *
61    * The default value is AccountStatus::Normal.
62    */
setNewUserStatus(AccountStatus status)63   void setNewUserStatus(AccountStatus status)
64   {
65     newUserStatus_ = status;
66   }
67 
startTransaction()68   virtual Transaction *startTransaction() override {
69     return new TransactionImpl(session_);
70   }
71 
72   /*! \brief Returns the %Dbo user type corresponding to an Auth::User.
73    */
find(const User & user)74   Wt::Dbo::ptr<DboType> find(const User& user) const {
75     getUser(user.id(), false);
76     return user_;
77   }
78 
79   /*! \brief Returns the Auth::User corresponding to a %Dbo user.
80    */
find(const Wt::Dbo::ptr<DboType> user)81   User find(const Wt::Dbo::ptr<DboType> user) const {
82     setUser(user);
83     return User(std::to_string(user_->id()), *this);
84   }
85 
findWithId(const std::string & id)86   virtual User findWithId(const std::string& id) const override {
87     getUser(id, false);
88 
89     if (user_)
90       return User(id, *this);
91     else
92       return User();
93   }
94 
findWithIdentity(const std::string & provider,const WString & identity)95   virtual User findWithIdentity(const std::string& provider,
96 				const WString& identity) const override {
97     if (userProvider_ != provider || userIdentity_ != identity) {
98       Wt::Dbo::Transaction t(session_);
99       Wt::Dbo::Query<Wt::Dbo::ptr<DboType> > query =
100 	  session_.query<Wt::Dbo::ptr<DboType> >(std::string() +
101 	       "select u from " + session_.tableNameQuoted<DboType>() + " u "
102 	       "join " + session_.tableNameQuoted<AuthIdentityType>() + " i "
103 	       "on u.id = i.\"" + session_.tableName<DboType>() + "_id\"")
104 	      .where("i.\"provider\" = ?").bind(provider);
105       if (authService_ && authService_->identityPolicy() == IdentityPolicy::EmailAddress) {
106 	query.where("lower(i.\"identity\") = lower(?)").bind(identity);
107       } else {
108 	query.where("i.\"identity\" = ?").bind(identity);
109       }
110       setUser(query.resultValue());
111       t.commit();
112     }
113 
114     if (user_) {
115       userProvider_ = provider;
116       userIdentity_ = identity;
117       return User(std::to_string(user_.id()), *this);
118     } else
119       return User();
120   }
121 
identity(const User & user,const std::string & provider)122   virtual WString identity(const User& user,
123 			   const std::string& provider) const override {
124     WithUser find(*this, user);
125 
126     AuthIdentities c
127       = user_->authIdentities().find().where("\"provider\" = ?").bind(provider);
128 
129     typename AuthIdentities::const_iterator i = c.begin();
130 
131     if (i != c.end())
132       return (*i)->identity();
133     else
134       return WString::Empty;
135   }
136 
removeIdentity(const User & user,const std::string & provider)137   virtual void removeIdentity(const User& user,
138 			      const std::string& provider) override {
139     Wt::Dbo::Transaction t(session_);
140 
141     session_.execute
142       (std::string() +
143        "delete from " + session_.tableNameQuoted<AuthIdentityType>() +
144        " where \"" + session_.tableName<DboType>() + "_id\" = ?"
145        " and \"provider\" = ?").bind(user.id()).bind(provider);
146 
147     t.commit();
148   }
149 
registerNew()150   virtual User registerNew() override {
151     auto user = std::make_unique<DboType>();
152     user->setStatus(newUserStatus_);
153     setUser(session_.add(std::move(user)));
154     user_.flush();
155     return User(std::to_string(user_.id()), *this);
156   }
157 
deleteUser(const User & user)158   virtual void deleteUser(const User& user) override {
159     Wt::Dbo::Transaction t(session_);
160     Wt::Dbo::ptr<DboType> u = find(user);
161     u.remove();
162     t.commit();
163   }
164 
status(const User & user)165   virtual AccountStatus status(const User& user) const override {
166     WithUser find(*this, user);
167     return user_->status();
168   }
169 
setStatus(const User & user,AccountStatus status)170   virtual void setStatus(const User& user, AccountStatus status) override {
171     WithUser find(*this, user);
172     user_.modify()->setStatus(status);
173   }
174 
setPassword(const User & user,const PasswordHash & password)175   virtual void setPassword(const User& user, const PasswordHash& password) override {
176     WithUser find(*this, user);
177     user_.modify()->setPassword(password.value(),
178 				password.function(),
179 				password.salt());
180   }
181 
password(const User & user)182   virtual PasswordHash password(const User& user) const override {
183     WithUser find(*this, user);
184     return PasswordHash(user_->passwordMethod(), user_->passwordSalt(),
185 			user_->passwordHash());
186   }
187 
addIdentity(const User & user,const std::string & provider,const WT_USTRING & identity)188   virtual void addIdentity(const User& user, const std::string& provider,
189 			   const WT_USTRING& identity) override {
190     WithUser find(*this, user);
191 
192     if (session_.find<AuthIdentityType>()
193 	.where("\"identity\" = ?").bind(identity)
194 	.where("\"provider\" = ?").bind(provider).resultList().size() != 0) {
195       Wt::log("error") << "cannot add identity " << provider
196 		       << ":'" << identity << "': already exists";
197       return;
198     }
199 
200     /*
201      * It's okay to have more than one identity from that provider
202      */
203     user_.modify()->authIdentities().insert
204       (Wt::Dbo::ptr<AuthIdentityType>(
205          std::make_unique<AuthIdentityType>(provider, identity)));
206   }
207 
setIdentity(const User & user,const std::string & provider,const WT_USTRING & identity)208   virtual void setIdentity(const User& user, const std::string& provider,
209 			   const WT_USTRING& identity) override {
210     WithUser find(*this, user);
211 
212     AuthIdentities c
213       = user_->authIdentities().find().where("\"provider\" = ?").bind(provider);
214 
215     typename AuthIdentities::const_iterator i = c.begin();
216 
217     if (i != c.end())
218       i->modify()->setIdentity(identity);
219     else
220       user_.modify()->authIdentities().insert
221         (Wt::Dbo::ptr<AuthIdentityType>(
222            std::make_unique<AuthIdentityType>(provider, identity)));
223   }
224 
setEmail(const User & user,const std::string & address)225   virtual bool setEmail(const User& user, const std::string& address) override {
226     WithUser find(*this, user);
227 
228     if (session_.find<DboType>().where("lower(\"email\") = lower(?)")
229 	.bind(address).resultList().size() != 0)
230       return false;
231 
232     user_.modify()->setEmail(address);
233 
234     return true;
235   }
236 
email(const User & user)237   virtual std::string email(const User& user) const override {
238     WithUser find(*this, user);
239     return user_->email();
240   }
241 
setUnverifiedEmail(const User & user,const std::string & address)242   virtual void setUnverifiedEmail(const User& user,
243 				  const std::string& address) override {
244     WithUser find(*this, user);
245     user_.modify()->setUnverifiedEmail(address);
246   }
247 
unverifiedEmail(const User & user)248   virtual std::string unverifiedEmail(const User& user) const override {
249     WithUser find(*this, user);
250     return user_->unverifiedEmail();
251   }
252 
findWithEmail(const std::string & address)253   virtual User findWithEmail(const std::string& address) const override {
254     Wt::Dbo::Transaction t(session_);
255     setUser(session_.find<DboType>().where("lower(\"email\") = lower(?)").bind(address));
256     t.commit();
257 
258     if (user_)
259       return User(std::to_string(user_.id()), *this);
260     else
261       return User();
262   }
263 
setEmailToken(const User & user,const Token & token,EmailTokenRole role)264   virtual void setEmailToken(const User& user, const Token& token,
265 			     EmailTokenRole role) override {
266     WithUser find(*this, user);
267     user_.modify()->setEmailToken(token.hash(), token.expirationTime(), role);
268   }
269 
emailToken(const User & user)270   virtual Token emailToken(const User& user) const override {
271     WithUser find(*this, user);
272     return Token(user_->emailToken(), user_->emailTokenExpires());
273   }
274 
emailTokenRole(const User & user)275   virtual EmailTokenRole emailTokenRole(const User& user) const override {
276     WithUser find(*this, user);
277     return user_->emailTokenRole();
278   }
279 
findWithEmailToken(const std::string & hash)280   virtual User findWithEmailToken(const std::string& hash) const override {
281     Wt::Dbo::Transaction t(session_);
282     setUser(session_.find<DboType>()
283 	    .where("\"email_token\" = ?").bind(hash));
284     t.commit();
285 
286     if (user_)
287       return User(std::to_string(user_.id()), *this);
288     else
289       return User();
290   }
291 
addAuthToken(const User & user,const Token & token)292   virtual void addAuthToken(const User& user, const Token& token) override {
293     WithUser find(*this, user);
294 
295     /*
296      * This should be statistically very unlikely but also a big
297      * security problem if we do not detect it ...
298      */
299     if (session_.find<AuthTokenType>().where("\"value\" = ?")
300 	.bind(token.hash())
301 	.resultList().size() > 0)
302       throw WException("Token hash collision");
303 
304     /*
305      * Prevent a user from piling up the database with tokens
306      */
307     size_t tokens_number = user_->authTokens().size();
308     if (tokens_number >= maxAuthTokensPerUser_) {
309       // remove so many tokens, that their number
310       // would be (maxAuthTokensPerUser_ - 1)
311       int tokens_to_remove = tokens_number - (maxAuthTokensPerUser_ - 1);
312       // remove the first token(s) to expire
313       Wt::Dbo::collection<Wt::Dbo::ptr<AuthTokenType> > earliest_tokens =
314         user_->authTokens().find().orderBy("expires").limit(tokens_to_remove);
315       std::vector<Wt::Dbo::ptr<AuthTokenType> > earliest_tokens_v(
316         earliest_tokens.begin(), earliest_tokens.end());
317 
318       for (auto& token : earliest_tokens_v)
319 	token.remove();
320     }
321 
322     user_.modify()->authTokens().insert
323       (Wt::Dbo::ptr<AuthTokenType>
324        (std::make_unique<AuthTokenType>(token.hash(), token.expirationTime())));
325   }
326 
removeAuthToken(const User & user,const std::string & hash)327   virtual void removeAuthToken(const User& user, const std::string& hash) override {
328     WithUser find(*this, user);
329 
330     for (typename AuthTokens::const_iterator i = user_->authTokens().begin();
331 	 i != user_->authTokens().end(); ++i) {
332       Wt::Dbo::ptr<AuthTokenType> t = *i;
333       if (t->value() == hash) {
334 	t.remove();
335 	break;
336       }
337     }
338   }
339 
updateAuthToken(const User & user,const std::string & hash,const std::string & newHash)340   virtual int updateAuthToken(const User& user, const std::string& hash,
341 			      const std::string& newHash) override {
342     WithUser find(*this, user);
343 
344     for (typename AuthTokens::const_iterator i = user_->authTokens().begin();
345 	 i != user_->authTokens().end(); ++i) {
346       Wt::Dbo::ptr<AuthTokenType> t = *i;
347       if (t->value() == hash) {
348 	t.modify()->setValue(newHash);
349 	return std::max(0, WDateTime::currentDateTime().secsTo(t->expires()));
350       }
351     }
352 
353     return 0;
354   }
355 
findWithAuthToken(const std::string & hash)356   virtual User findWithAuthToken(const std::string& hash) const override {
357     Wt::Dbo::Transaction t(session_);
358     setUser(session_.query< Wt::Dbo::ptr<DboType> >
359 	    (std::string() +
360 	     "select u from " + session_.tableNameQuoted<DboType>() + " u "
361 	     "join " + session_.tableNameQuoted<AuthTokenType>() + " t "
362 	     "on u.id = t.\"" + session_.tableName<DboType>() + "_id\"")
363 	    .where("t.\"value\" = ?").bind(hash)
364 	    .where("t.\"expires\" > ?").bind(WDateTime::currentDateTime()));
365     t.commit();
366 
367     if (user_)
368       return User(std::to_string(user_.id()), *this);
369     else
370       return User();
371   }
372 
setFailedLoginAttempts(const User & user,int count)373   virtual void setFailedLoginAttempts(const User& user, int count) override {
374     WithUser find(*this, user, true);
375     return user_.modify()->setFailedLoginAttempts(count);
376   }
377 
failedLoginAttempts(const User & user)378   virtual int failedLoginAttempts(const User& user) const override {
379     WithUser find(*this, user, true);
380     return user_->failedLoginAttempts();
381   }
382 
setLastLoginAttempt(const User & user,const WDateTime & t)383   virtual void setLastLoginAttempt(const User& user, const WDateTime& t) override {
384     WithUser find(*this, user, true);
385     return user_.modify()->setLastLoginAttempt(t);
386   }
387 
lastLoginAttempt(const User & user)388   virtual WDateTime lastLoginAttempt(const User& user) const override {
389     WithUser find(*this, user, true);
390     return user_->lastLoginAttempt();
391   }
392 
393   /*! \brief Returns max number of tokens user can have in the database
394    *
395    * Default value is 50.
396    */
maxAuthTokensPerUser()397   unsigned maxAuthTokensPerUser() const {
398     return maxAuthTokensPerUser_;
399   }
400 
401   /*! \brief Sets max number of tokens user can have in the database
402    */
setMaxAuthTokensPerUser(unsigned maxAuthTokensPerUser)403   void setMaxAuthTokensPerUser(unsigned maxAuthTokensPerUser) {
404     maxAuthTokensPerUser_ = maxAuthTokensPerUser;
405   }
406 
407 private:
408   Wt::Dbo::Session& session_;
409   AccountStatus newUserStatus_;
410   const AuthService *authService_;
411   mutable Wt::Dbo::ptr<DboType> user_;
412   mutable std::string userProvider_;
413   mutable Wt::WString userIdentity_;
414   unsigned maxAuthTokensPerUser_;
415 
416   struct WithUser {
417     WithUser(const UserDatabase<DboType>& self, const User& user,
418 	     bool reread = false)
419       : transaction(self.session_)
420     {
421       self.getUser(user.id(), reread);
422       if (!self.user_)
423 	throw WException("Invalid user");
424     }
425 
~WithUserWithUser426     ~WithUser() {
427       transaction.commit();
428     }
429 
430     Wt::Dbo::Transaction transaction;
431   };
432 
getUser(const std::string & id,bool reread)433   void getUser(const std::string& id, bool reread) const {
434     if (!user_ || std::to_string(user_.id()) != id) {
435       Wt::Dbo::Transaction t(session_);
436       setUser(session_.load<DboType>(std::stoll(id)));
437       t.commit();
438     } else
439       if (reread && !user_.isDirty())
440 	user_.reread();
441   }
442 
setUser(Wt::Dbo::ptr<DboType> user)443   void setUser(Wt::Dbo::ptr<DboType> user) const {
444     user_ = user;
445     userProvider_.clear();
446     userIdentity_ = WString::Empty;
447   }
448 
449   struct TransactionImpl final : public Transaction, public Wt::Dbo::Transaction
450   {
TransactionImplfinal451     TransactionImpl(Wt::Dbo::Session& session)
452       : Wt::Dbo::Transaction(session)
453     { }
454 
~TransactionImplfinal455     virtual ~TransactionImpl()
456     { }
457 
commitfinal458     virtual void commit() override
459     {
460       Wt::Dbo::Transaction::commit();
461     }
462 
rollbackfinal463     virtual void rollback() override
464     {
465       Wt::Dbo::Transaction::rollback();
466     }
467   };
468 };
469 
470     }
471   }
472 }
473 
474 #endif // WT_AUTH_DBO_USER_DATABASE
475