1 // This may look like C code, but it's really -*- C++ -*-
2 /*
3  * Copyright (C) 2008 Emweb bv, Herent, Belgium.
4  *
5  * See the LICENSE file for terms of use.
6  */
7 
8 #include "Wt/Dbo/Call.h"
9 #include "Wt/Dbo/Exception.h"
10 #include "Wt/Dbo/Logger.h"
11 #include "Wt/Dbo/Session.h"
12 #include "Wt/Dbo/SqlConnection.h"
13 #include "Wt/Dbo/SqlConnectionPool.h"
14 #include "Wt/Dbo/SqlStatement.h"
15 #include "Wt/Dbo/StdSqlTraits.h"
16 #include "Wt/Dbo/StringStream.h"
17 
18 #include <iostream>
19 #include <vector>
20 #include <string>
21 
22 #include <boost/multi_index_container.hpp>
23 #include <boost/multi_index/hashed_index.hpp>
24 #include <boost/multi_index/sequenced_index.hpp>
25 #include <boost/multi_index/member.hpp>
26 
27 namespace Wt {
28   namespace Dbo {
29 
30 LOGGER("Dbo.Session");
31 
32     namespace Impl {
33 
34 struct MetaDboBaseSet : public boost::multi_index::multi_index_container<
35     MetaDboBase *,
36     boost::multi_index::indexed_by<
37       boost::multi_index::sequenced<>,
38       boost::multi_index::hashed_unique
39       <boost::multi_index::identity<MetaDboBase *> >
40     >
41   >
42 { };
43 
/*
 * Replaces, in place, every occurrence of character c in s by the string r
 * and returns s.  The search resumes right after each inserted copy of r,
 * so occurrences of c inside r itself are not replaced again.
 */
std::string& replace(std::string& s, char c, const std::string& r)
{
  for (std::string::size_type pos = s.find(c);
       pos != std::string::npos;
       pos = s.find(c, pos + r.size()))
    s.replace(pos, 1, r);

  return s;
}
55 
/*
 * Turns "schema.table" into schema"."table so that, once the caller wraps
 * the result in double quotes, every dot-separated component is quoted
 * separately ("schema"."table").  Names without a dot are returned as is.
 */
std::string quoteSchemaDot(const std::string& table) {
  std::string quoted;
  quoted.reserve(table.size());

  for (char ch : table) {
    if (ch == '.')
      quoted += "\".\"";
    else
      quoted += ch;
  }

  return quoted;
}
61 
SetInfo(const char * aTableName,RelationType aType,const std::string & aJoinName,const std::string & aJoinSelfId,int someFkConstraints)62 SetInfo::SetInfo(const char *aTableName,
63 		 RelationType aType,
64 		 const std::string& aJoinName,
65 		 const std::string& aJoinSelfId,
66 		 int someFkConstraints)
67   : tableName(aTableName),
68     joinName(aJoinName),
69     joinSelfId(aJoinSelfId),
70     flags(0),
71     type(aType),
72     fkConstraints(someFkConstraints)
73 { }
74 
MappingInfo()75 Impl::MappingInfo::MappingInfo()
76   : initialized_(false)
77 { }
78 
~MappingInfo()79 MappingInfo::~MappingInfo()
80 { }
81 
init(Session & session)82 void MappingInfo::init(Session& session)
83 {
84   throw Exception("Not to be done.");
85 }
86 
dropTable(Session & session,std::set<std::string> & tablesDropped)87 void MappingInfo::dropTable(Session& session,
88 				     std::set<std::string>& tablesDropped)
89 {
90   throw Exception("Not to be done.");
91 }
92 
rereadAll()93 void MappingInfo::rereadAll()
94 {
95   throw Exception("Not to be done.");
96 }
97 
create(Session & session)98 MetaDboBase *MappingInfo::create(Session& session)
99 {
100   throw Exception("Not to be done.");
101 }
102 
load(Session & session,MetaDboBase * obj)103 void MappingInfo::load(Session& session, MetaDboBase *obj)
104 {
105   throw Exception("Not to be done.");
106 }
107 
load(Session & session,SqlStatement * statement,int & column)108 MetaDboBase *MappingInfo::load(Session& session, SqlStatement *statement,
109 			       int& column)
110 {
111   throw Exception("Not to be done.");
112 }
113 
primaryKeys()114 std::string MappingInfo::primaryKeys() const
115 {
116   if (surrogateIdFieldName)
117     return std::string("\"") + surrogateIdFieldName + "\"";
118   else {
119     std::stringstream result;
120 
121     bool firstField = true;
122     for (unsigned i = 0; i < fields.size(); ++i)
123       if (fields[i].isIdField()) {
124 	if (!firstField)
125 	  result << ", ";
126 	result << "\"" << fields[i].name() << "\"";
127 	firstField = false;
128       }
129 
130     return result.str();
131   }
132 }
133 
134     } // end namespace Impl
135 
JoinId(const std::string & aJoinIdName,const std::string & aTableIdName,const std::string & aSqlType)136 Session::JoinId::JoinId(const std::string& aJoinIdName,
137 			const std::string& aTableIdName,
138 			const std::string& aSqlType)
139   : joinIdName(aJoinIdName),
140     tableIdName(aTableIdName),
141     sqlType(aSqlType)
142 { }
143 
Session()144 Session::Session()
145   : schemaInitialized_(false),
146     //useRowsFromTo_(false),
147     requireSubqueryAlias_(false),
148     dirtyObjects_(new Impl::MetaDboBaseSet()),
149     connection_(nullptr),
150     connectionPool_(nullptr),
151     transaction_(nullptr),
152     flushMode_(FlushMode::Auto)
153 { }
154 
~Session()155 Session::~Session()
156 {
157   if (!dirtyObjects_->empty()) {
158     LOG_WARN("Session exiting with " << dirtyObjects_->size() << " dirty objects");
159   }
160 
161   while (!dirtyObjects_->empty()) {
162     MetaDboBase *b = *dirtyObjects_->begin();
163     discardChanges(b);
164   }
165 
166   dirtyObjects_->clear();
167   delete dirtyObjects_;
168 
169   for (ClassRegistry::iterator i = classRegistry_.begin();
170        i != classRegistry_.end(); ++i)
171     delete i->second;
172 }
173 
setConnection(std::unique_ptr<SqlConnection> connection)174 void Session::setConnection(std::unique_ptr<SqlConnection> connection)
175 {
176   connection_ = std::move(connection);
177 }
178 
setConnectionPool(SqlConnectionPool & pool)179 void Session::setConnectionPool(SqlConnectionPool &pool)
180 {
181   connectionPool_ = &pool;
182 }
183 
connection(bool openTransaction)184 SqlConnection *Session::connection(bool openTransaction)
185 {
186   if (!transaction_)
187     throw Exception("Operation requires an active transaction");
188 
189   if (openTransaction)
190     transaction_->open();
191 
192   return transaction_->connection_.get();
193 }
194 
useConnection()195 std::unique_ptr<SqlConnection> Session::useConnection()
196 {
197   if (connectionPool_)
198     return connectionPool_->getConnection();
199   else
200     return std::move(connection_);
201 }
202 
returnConnection(std::unique_ptr<SqlConnection> connection)203 void Session::returnConnection(std::unique_ptr<SqlConnection> connection)
204 {
205   if (connectionPool_)
206     connectionPool_->returnConnection(std::move(connection));
207   else
208     connection_ = std::move(connection);
209 }
210 
discardChanges(MetaDboBase * obj)211 void Session::discardChanges(MetaDboBase *obj)
212 {
213   Impl::MetaDboBaseSet::nth_index<1>::type& setIndex = dirtyObjects_->get<1>();
214 
215   if (setIndex.erase(obj) > 0)
216     obj->decRef();
217 
218   // FIXME what about Transaction.objects_ ?
219 }
220 
execute(const std::string & sql)221 Call Session::execute(const std::string& sql)
222 {
223   initSchema();
224 
225   if (!transaction_)
226     throw Exception("Dbo execute(): no active transaction");
227 
228   return Call(*this, sql);
229 }
230 
initSchema()231 void Session::initSchema() const
232 {
233   if (schemaInitialized_)
234     return;
235 
236   Session *self = const_cast<Session *>(this);
237   self->schemaInitialized_ = true;
238 
239   Transaction t(*self);
240 
241   SqlConnection *conn = self->connection(false);
242   longlongType_ = sql_value_traits<long long>::type(conn, 0);
243   intType_ = sql_value_traits<int>::type(conn, 0);
244   haveSupportUpdateCascade_ = conn->supportUpdateCascade();
245   limitQueryMethod_ = conn->limitQueryMethod();
246   requireSubqueryAlias_ = conn->requireSubqueryAlias();
247 
248   for (ClassRegistry::const_iterator i = classRegistry_.begin();
249        i != classRegistry_.end(); ++i)
250     i->second->init(*self);
251 
252   for (ClassRegistry::const_iterator i = classRegistry_.begin();
253        i != classRegistry_.end(); ++i)
254     self->resolveJoinIds(i->second);
255 
256   for (ClassRegistry::const_iterator i = classRegistry_.begin();
257        i != classRegistry_.end(); ++i)
258     self->prepareStatements(i->second);
259 
260   t.commit();
261 }
262 
prepareStatements(Impl::MappingInfo * mapping)263 void Session::prepareStatements(Impl::MappingInfo *mapping)
264 {
265   std::stringstream sql;
266 
267   std::string table = Impl::quoteSchemaDot(mapping->tableName);
268 
269   /*
270    * SqlInsert
271    */
272   sql << "insert into \"" << table << "\" (";
273 
274   bool firstField = true;
275 
276   if (mapping->versionFieldName) {
277     sql << "\"" << mapping->versionFieldName << "\"";
278     firstField = false;
279   }
280 
281   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
282     if (!firstField)
283       sql << ", ";
284     sql << "\"" << mapping->fields[i].name() << "\"";
285     firstField = false;
286   }
287 
288   sql << ")";
289 
290   std::unique_ptr<SqlConnection> connPtr;
291   SqlConnection *conn;
292   if (transaction_)
293     conn = transaction_->connection_.get();
294   else {
295     connPtr = useConnection();
296     conn = connPtr.get();
297   }
298 
299   if (mapping->surrogateIdFieldName) {
300     sql << conn->autoincrementInsertInfix(mapping->surrogateIdFieldName);
301   }
302 
303   sql << " values (";
304 
305   firstField = true;
306   if (mapping->versionFieldName) {
307     sql << "?";
308     firstField = false;
309   }
310 
311   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
312     if (!firstField)
313       sql << ", ";
314     sql << "?";
315     firstField = false;
316   }
317 
318   sql << ")";
319 
320   if (mapping->surrogateIdFieldName) {
321     sql << conn->autoincrementInsertSuffix(mapping->surrogateIdFieldName);
322   }
323 
324   if (!transaction_)
325     returnConnection(std::move(connPtr));
326 
327   mapping->statements.push_back(sql.str()); // SqlInsert
328 
329   /*
330    * SqlUpdate
331    */
332 
333   sql.str("");
334 
335   sql << "update \"" << table << "\" set ";
336 
337   firstField = true;
338 
339   if (mapping->versionFieldName) {
340     sql << "\"" << mapping->versionFieldName << "\" = ?";
341     firstField = false;
342   }
343 
344   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
345     if (!firstField)
346       sql << ", ";
347     sql << "\"" << mapping->fields[i].name() << "\" = ?";
348 
349     firstField = false;
350   }
351 
352   sql << " where ";
353 
354   std::string idCondition;
355   std::string modifyIdCondition;
356 
357   if (!mapping->surrogateIdFieldName) {
358     firstField = true;
359 
360     for (unsigned i = 0; i < mapping->fields.size(); ++i) {
361       if (mapping->fields[i].isNaturalIdField()) {
362 	if (!firstField)
363 	  idCondition += " and ";
364 	idCondition += "\"" + mapping->fields[i].name() + "\" = ?";
365 
366 	firstField = false;
367       }
368     }
369 
370     if (firstField)
371       throw Exception("Table " + std::string(mapping->tableName)
372 		      + " is missing a natural id defined with Wt::Dbo::id()");
373   } else
374     idCondition
375       += std::string() + "\"" + mapping->surrogateIdFieldName + "\" = ?";
376 
377   modifyIdCondition = idCondition;
378   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
379     if (mapping->fields[i].isAuxIdField()) {
380       modifyIdCondition += " and ";
381       modifyIdCondition += "\"" + mapping->fields[i].name() + "\" = ?";
382     }
383   }
384 
385   mapping->idCondition = idCondition;
386 
387   sql << modifyIdCondition;
388 
389   if (mapping->versionFieldName)
390     sql << " and \"" << mapping->versionFieldName << "\" = ?";
391 
392   mapping->statements.push_back(sql.str()); // SqlUpdate
393 
394   /*
395    * SqlDelete
396    */
397 
398   sql.str("");
399 
400   sql << "delete from \"" << table << "\" where " << modifyIdCondition;
401 
402   mapping->statements.push_back(sql.str()); // SqlDelete
403 
404   /*
405    * SqlDeleteVersioned
406    */
407   if (mapping->versionFieldName)
408     sql << " and \"" << mapping->versionFieldName << "\" = ?";
409 
410   mapping->statements.push_back(sql.str()); // SqlDeleteVersioned
411 
412   /*
413    * SelectedById
414    */
415 
416   sql.str("");
417 
418   sql << "select ";
419 
420   firstField = true;
421   if (mapping->versionFieldName) {
422     sql << "\"" << mapping->versionFieldName << "\"";
423     firstField = false;
424   }
425 
426   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
427     if (!firstField)
428       sql << ", ";
429     sql << "\"" << mapping->fields[i].name() << "\"";
430     firstField = false;
431   }
432 
433   sql << " from \"" << table << "\" where " << idCondition;
434 
435   mapping->statements.push_back(sql.str()); // SelectById
436 
437   /*
438    * Collections SQL
439    */
440   for (unsigned i = 0; i < mapping->sets.size(); ++i) {
441     const Impl::SetInfo& info = mapping->sets[i];
442 
443     sql.str("");
444 
445     Impl::MappingInfo *otherMapping = getMapping(info.tableName);
446 
447     // select [surrogate id,] version, ... from other
448 
449     sql << "select ";
450 
451     firstField = true;
452     if (otherMapping->surrogateIdFieldName) {
453       sql << "\"" << otherMapping->surrogateIdFieldName << "\"";
454       firstField = false;
455     }
456 
457     if (otherMapping->versionFieldName) {
458       if (!firstField)
459 	sql << ", ";
460       sql << "\"" << otherMapping->versionFieldName << "\"";
461       firstField = false;
462     }
463 
464     std::string fkConditions;
465     std::string other;
466 
467     for (unsigned i = 0; i < otherMapping->fields.size(); ++i) {
468       if (!firstField)
469 	sql << ", ";
470       firstField = false;
471 
472       const FieldInfo& field = otherMapping->fields[i];
473       sql << "\"" << field.name() << "\"";
474 
475       if (field.isForeignKey()
476 	  && field.foreignKeyTable() == mapping->tableName) {
477 	if (field.foreignKeyName() == info.joinName) {
478 	  if (!fkConditions.empty())
479 	    fkConditions += " and ";
480 	  fkConditions += std::string("\"") + field.name() + "\" = ?";
481 	} else {
482 	  if (!other.empty())
483 	    other += " and ";
484 
485 	  other += "'" + field.foreignKeyName() + "'";
486 	}
487       }
488     }
489 
490     sql << " from \"" << Impl::quoteSchemaDot(otherMapping->tableName);
491 
492     switch (info.type) {
493     case ManyToOne:
494       // where joinfield_id(s) = ?
495 
496       if (fkConditions.empty()) {
497 	std::string msg = std::string()
498 	  + "Relation mismatch for table '" + mapping->tableName
499 	  + "': no matching belongsTo() found in table '"
500 	  + otherMapping->tableName + "' with name '" + info.joinName
501 	  + "'";
502 
503 	if (!other.empty())
504 	  msg += ", but did find with name " + other + "?";
505 
506 	throw Exception(msg);
507       }
508 
509       sql << "\" where " << fkConditions;
510 
511       mapping->statements.push_back(sql.str());
512       break;
513     case ManyToMany:
514       // (1) select for collection
515 
516       //     join "joinName" on "joinName"."joinId(s) = this."id(s)
517       //     where joinfield_id(s) = ?
518 
519       std::string joinName = Impl::quoteSchemaDot(info.joinName);
520       std::string tableName = Impl::quoteSchemaDot(info.tableName);
521 
522       sql << "\" join \"" << joinName
523 	  << "\" on ";
524 
525       std::vector<JoinId> otherJoinIds
526 	= getJoinIds(otherMapping, info.joinOtherId, info.flags & Impl::SetInfo::LiteralOtherId);
527 
528       if (otherJoinIds.size() > 1)
529 	sql << "(";
530 
531       for (unsigned i = 0; i < otherJoinIds.size(); ++i) {
532 	if (i != 0)
533 	  sql << " and ";
534 	sql << "\"" << joinName << "\".\"" << otherJoinIds[i].joinIdName
535 	    << "\" = \""
536 	    << tableName << "\".\"" << otherJoinIds[i].tableIdName << "\"";
537       }
538 
539       if (otherJoinIds.size() > 1)
540 	sql << ")";
541 
542       sql << " where ";
543 
544       std::vector<JoinId> selfJoinIds
545 	= getJoinIds(mapping, info.joinSelfId, info.flags & Impl::SetInfo::LiteralSelfId);
546 
547       for (unsigned i = 0; i < selfJoinIds.size(); ++i) {
548 	if (i != 0)
549 	  sql << " and ";
550 	sql << "\"" << joinName << "\".\"" << selfJoinIds[i].joinIdName
551 	    << "\" = ?";
552       }
553 
554       mapping->statements.push_back(sql.str());
555 
556       // (2) insert into collection
557 
558       sql.str("");
559 
560       sql << "insert into \"" << joinName
561 	  << "\" (";
562 
563       firstField = true;
564       for (unsigned i = 0; i < selfJoinIds.size(); ++i) {
565 	if (!firstField)
566 	  sql << ", ";
567 	firstField = false;
568 
569 	sql << "\"" << selfJoinIds[i].joinIdName << "\"";
570       }
571 
572       for (unsigned i = 0; i < otherJoinIds.size(); ++i) {
573 	if (!firstField)
574 	  sql << ", ";
575 	firstField = false;
576 
577 	sql << "\"" << otherJoinIds[i].joinIdName << "\"";
578       }
579 
580       sql << ") values (";
581 
582       for (unsigned i = 0; i < selfJoinIds.size() + otherJoinIds.size(); ++i) {
583 	if (i != 0)
584 	  sql << ", ";
585 	sql << "?";
586       }
587 
588       sql << ")";
589 
590       mapping->statements.push_back(sql.str());
591 
592       // (3) delete from collections
593 
594       sql.str("");
595 
596       sql << "delete from \"" << joinName << "\" where ";
597 
598       firstField = true;
599       for (unsigned i = 0; i < selfJoinIds.size(); ++i) {
600 	if (!firstField)
601 	  sql << " and ";
602 	firstField = false;
603 
604 	sql << "\"" << selfJoinIds[i].joinIdName << "\" = ?";
605       }
606 
607       for (unsigned i = 0; i < otherJoinIds.size(); ++i) {
608 	if (!firstField)
609 	  sql << " and ";
610 	firstField = false;
611 
612 	sql << "\"" << otherJoinIds[i].joinIdName << "\" = ?";
613       }
614 
615       mapping->statements.push_back(sql.str());
616     }
617   }
618 }
619 
executeSql(std::vector<std::string> & sql,std::ostream * sout)620 void Session::executeSql(std::vector<std::string>& sql, std::ostream *sout)
621 {
622   for (unsigned i = 0; i < sql.size(); i++)
623     if (sout)
624       *sout << sql[i] << ";\n";
625     else
626       connection(true)->executeSql(sql[i]);
627 }
628 
executeSql(std::stringstream & sql,std::ostream * sout)629 void Session::executeSql(std::stringstream& sql, std::ostream *sout)
630 {
631   if (sout)
632     *sout << sql.str() << ";\n";
633   else
634     connection(true)->executeSql(sql.str());
635 }
636 
constraintName(const char * tableName,std::string foreignKeyName)637 std::string Session::constraintName(const char *tableName,
638                            std::string foreignKeyName)
639 {
640   std::stringstream ans;
641   ans << "\"fk_"<<tableName << "_" << foreignKeyName << "\"";
642   return ans.str();
643 }
644 
645 
646 /*
647 void Session::mergeDuplicates(Impl::MappingInfo *mapping)
648 {
649   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
650     FieldInfo& f = mapping->fields[i];
651     for (unsigned j = i + 1; j < mapping->fields.size(); ++j) {
652       FieldInfo& f2 = mapping->fields[j];
653       if (f.name() == f2.name()) {
654 	if (f.sqlType() != f2.sqlType())
655 	  throw Exception("Table: " + mapping->tableName + ": field '"
656 			  + f.name() + "' mapped multiple times");
657 			  "for " + mapping->tableName + "."
658 			  + set.joinName);
659 
660       }
661     }
662   }
663 }
664 */
665 
resolveJoinIds(Impl::MappingInfo * mapping)666 void Session::resolveJoinIds(Impl::MappingInfo *mapping)
667 {
668   for (unsigned i = 0; i < mapping->sets.size(); ++i) {
669     Impl::SetInfo& set = mapping->sets[i];
670 
671     if (set.type == ManyToMany) {
672       Impl::MappingInfo *other = getMapping(set.tableName);
673 
674       for (unsigned j = 0; j < other->sets.size(); ++j) {
675 	const Impl::SetInfo& otherSet = other->sets[j];
676 
677 	if (otherSet.joinName == set.joinName) {
678 	  // second check make sure we find the other id if Many-To-Many between
679 	  // same table
680 	  if (mapping != other || i != j) {
681 	    set.joinOtherId = otherSet.joinSelfId;
682 	    set.otherFkConstraints = otherSet.fkConstraints;
683 	    if (otherSet.flags & Impl::SetInfo::LiteralSelfId)
684 	      set.flags |= Impl::SetInfo::LiteralOtherId;
685 	    break;
686 	  }
687 	}
688       }
689     }
690   }
691 }
692 
tableCreationSql()693 std::string Session::tableCreationSql()
694 {
695   initSchema();
696 
697   std::stringstream sout;
698 
699   Transaction t(*this);
700 
701   std::set<std::string> tablesCreated;
702 
703   for (ClassRegistry::iterator i = classRegistry_.begin();
704        i != classRegistry_.end(); ++i)
705     createTable(i->second, tablesCreated, &sout, false);
706 
707   for (ClassRegistry::iterator i = classRegistry_.begin();
708        i != classRegistry_.end(); ++i)
709     createRelations(i->second, tablesCreated, &sout);
710 
711   t.commit();
712 
713   return sout.str();
714 }
715 
createTables()716 void Session::createTables()
717 {
718   initSchema();
719 
720   Transaction t(*this);
721 
722   std::set<std::string> tablesCreated;
723 
724   for (ClassRegistry::iterator i = classRegistry_.begin();
725        i != classRegistry_.end(); ++i)
726     createTable(i->second, tablesCreated, nullptr, false);
727 
728   for (ClassRegistry::iterator i = classRegistry_.begin();
729        i != classRegistry_.end(); ++i)
730     createRelations(i->second, tablesCreated, nullptr);
731 
732   t.commit();
733 }
734 
createTable(Impl::MappingInfo * mapping,std::set<std::string> & tablesCreated,std::ostream * sout,bool createConstraints)735 void Session::createTable(Impl::MappingInfo *mapping,
736 			  std::set<std::string>& tablesCreated,
737                           std::ostream *sout,
738                           bool createConstraints)
739 {
740   if (tablesCreated.count(mapping->tableName) != 0)
741     return;
742 
743   tablesCreated.insert(mapping->tableName);
744 
745   std::stringstream sql;
746 
747   sql << "create table \"" << Impl::quoteSchemaDot(mapping->tableName)
748       << "\" (\n";
749 
750   bool firstField = true;
751 
752   // Auto-generated id
753   if (mapping->surrogateIdFieldName) {
754     sql << "  \"" << mapping->surrogateIdFieldName << "\" "
755 	<< connection(false)->autoincrementType()
756 	<< " primary key "
757 	<< connection(false)->autoincrementSql() << "";
758     firstField = false;
759   }
760 
761   // Optimistic locking version field
762   if (mapping->versionFieldName) {
763     if (!firstField)
764       sql << ",\n";
765 
766     sql << "  \"" << mapping->versionFieldName << "\" "
767 	<< sql_value_traits<int>::type(0, 0);
768 
769     firstField = false;
770   }
771 
772   std::string primaryKey;
773   for (unsigned i = 0; i < mapping->fields.size(); ++i) {
774     const FieldInfo& field = mapping->fields[i];
775 
776     if (!field.isVersionField()) {
777       if (!firstField)
778 	sql << ",\n";
779 
780       std::string sqlType = field.sqlType();
781       if (field.isForeignKey() && !(field.fkConstraints() & Impl::FKNotNull)) {
782 	if (sqlType.length() > 9
783 	    && sqlType.substr(sqlType.length() - 9) == " not null")
784 	  sqlType = sqlType.substr(0, sqlType.length() - 9);
785       }
786 
787       sql << "  \"" << field.name() << "\" " << sqlType;
788 
789       firstField = false;
790 
791       if (field.isNaturalIdField()) {
792 	if (!primaryKey.empty())
793 	  primaryKey += ", ";
794 	primaryKey += "\"" + field.name() + "\"";
795       }
796     }
797   }
798 
799   if (!primaryKey.empty()) {
800     if (!firstField)
801       sql << ",\n";
802 
803     sql << "  primary key (" << primaryKey << ")";
804   }
805 
806   for (unsigned i = 0; i < mapping->fields.size();) {
807     const FieldInfo& field = mapping->fields[i];
808 
809     if (field.isForeignKey() &&
810 	(createConstraints || !connection(false)->supportAlterTable())) {
811       if (!firstField)
812 	sql << ",\n";
813 
814       unsigned firstI = i;
815       i = findLastForeignKeyField(mapping, field, firstI);
816       sql << "  " << constraintString(mapping, field, firstI, i);
817 
818       createTable(mapping, tablesCreated, sout, false);
819     } else
820       ++i;
821   }
822 
823   sql << "\n)";
824 
825   executeSql(sql, sout);
826 
827   if (mapping->surrogateIdFieldName) {
828     std::string tableName = Impl::quoteSchemaDot(mapping->tableName);
829     std::string idFieldName = mapping->surrogateIdFieldName;
830 
831     std::vector<std::string> sql =
832       connection(false)->autoincrementCreateSequenceSql(tableName,
833 							idFieldName);
834 
835     executeSql(sql, sout);
836   }
837 }
838 
createRelations(Impl::MappingInfo * mapping,std::set<std::string> & tablesCreated,std::ostream * sout)839 void Session::createRelations(Impl::MappingInfo *mapping,
840 			      std::set<std::string>& tablesCreated,
841 			      std::ostream *sout)
842 {
843   for (unsigned i = 0; i < mapping->sets.size(); ++i) {
844     const Impl::SetInfo& set = mapping->sets[i];
845 
846     if (set.type == ManyToMany) {
847       if (tablesCreated.count(set.joinName) == 0) {
848 	Impl::MappingInfo *other = getMapping(set.tableName);
849 
850 	createJoinTable(set.joinName, mapping, other,
851 			set.joinSelfId, set.joinOtherId,
852 			set.fkConstraints, set.otherFkConstraints,
853 			set.flags & Impl::SetInfo::LiteralSelfId,
854 			set.flags & Impl::SetInfo::LiteralOtherId,
855 			tablesCreated, sout);
856       }
857     }
858   }
859 
860   if (connection(false)->supportAlterTable()){ //backend condition
861     for (unsigned i = 0; i < mapping->fields.size();) {
862       const FieldInfo& field = mapping->fields[i];
863       if (field.isForeignKey()){
864         std::stringstream sql;
865 
866 	std::string table = Impl::quoteSchemaDot(mapping->tableName);
867 
868         sql << "alter table \"" << table << "\""
869             << " add ";
870 
871         unsigned firstI = i;
872         i = findLastForeignKeyField(mapping, field, firstI);
873         sql << constraintString(mapping, field, firstI, i);
874 
875         executeSql(sql, sout);
876 
877       } else
878         ++i;
879     }
880   }
881 }
882 
883 //constraint fk_... foreign key ( ..., .. , .. ) references (..)
constraintString(Impl::MappingInfo * mapping,const FieldInfo & field,unsigned fromIndex,unsigned toIndex)884 std::string Session::constraintString(Impl::MappingInfo *mapping,
885                                       const FieldInfo& field,
886                                       unsigned fromIndex,
887                                       unsigned toIndex)
888 {
889   std::stringstream sql;
890 
891   sql << "constraint \"fk_"
892       << mapping->tableName << "_" << field.foreignKeyName() << "\""
893       << " foreign key (\"" << field.name() << "\"";
894 
895   for(unsigned i = fromIndex + 1; i < toIndex; ++i){
896     const FieldInfo& nextField = mapping->fields[i];
897     sql << ", \"" << nextField.name() << "\"";
898   }
899 
900   Impl::MappingInfo *otherMapping = getMapping(field.foreignKeyTable().c_str());
901 
902   sql << ") references \"" << Impl::quoteSchemaDot(field.foreignKeyTable())
903       << "\" (" << otherMapping->primaryKeys() << ")";
904 
905   if (field.fkConstraints() & Impl::FKOnUpdateCascade
906       && haveSupportUpdateCascade_)
907     sql << " on update cascade";
908   else if (field.fkConstraints() & Impl::FKOnUpdateSetNull
909 	   && haveSupportUpdateCascade_)
910     sql << " on update set null";
911   else if (field.fkConstraints() & Impl::FKOnUpdateRestrict
912            && haveSupportUpdateCascade_)
913     sql << " on update restrict";
914 
915   if (field.fkConstraints() & Impl::FKOnDeleteCascade)
916     sql << " on delete cascade";
917   else if (field.fkConstraints() & Impl::FKOnDeleteSetNull)
918     sql << " on delete set null";
919   else if (field.fkConstraints() & Impl::FKOnDeleteRestrict)
920     sql << " on delete restrict";
921 
922   if (connection(false)->supportDeferrableFKConstraint()) //backend condition
923     sql << " deferrable initially deferred";
924 
925   return sql.str();
926 }
927 
findLastForeignKeyField(Impl::MappingInfo * mapping,const FieldInfo & field,unsigned index)928 unsigned Session::findLastForeignKeyField(Impl::MappingInfo *mapping,
929                                  const FieldInfo& field,
930                                  unsigned index)
931 {
932   while (index < mapping->fields.size()) {
933     const FieldInfo& nextField = mapping->fields[index];
934     if (nextField.foreignKeyName() == field.foreignKeyName()) {
935       ++index;
936     } else
937       break;
938   }
939 
940   return index;
941 }
942 
createJoinTable(const std::string & joinName,Impl::MappingInfo * mapping1,Impl::MappingInfo * mapping2,const std::string & joinId1,const std::string & joinId2,int fkConstraints1,int fkConstraints2,bool literalJoinId1,bool literalJoinId2,std::set<std::string> & tablesCreated,std::ostream * sout)943 void Session::createJoinTable(const std::string& joinName,
944 			      Impl::MappingInfo *mapping1,
945 			      Impl::MappingInfo *mapping2,
946 			      const std::string& joinId1,
947 			      const std::string& joinId2,
948 			      int fkConstraints1, int fkConstraints2,
949 			      bool literalJoinId1, bool literalJoinId2,
950 			      std::set<std::string>& tablesCreated,
951 			      std::ostream *sout)
952 {
953   Impl::MappingInfo joinTableMapping;
954 
955   joinTableMapping.tableName = joinName.c_str();
956   joinTableMapping.versionFieldName = nullptr;
957   joinTableMapping.surrogateIdFieldName = nullptr;
958 
959   addJoinTableFields(joinTableMapping, mapping1, joinId1, "key1",
960 		     fkConstraints1, literalJoinId1);
961   addJoinTableFields(joinTableMapping, mapping2, joinId2, "key2",
962 		     fkConstraints2, literalJoinId2);
963 
964   createTable(&joinTableMapping, tablesCreated, sout, true);
965 
966   createJoinIndex(joinTableMapping, mapping1, joinId1, "key1", sout);
967   createJoinIndex(joinTableMapping, mapping2, joinId2, "key2", sout);
968 }
969 
createJoinIndex(Impl::MappingInfo & joinTableMapping,Impl::MappingInfo * mapping,const std::string & joinId,const std::string & foreignKeyName,std::ostream * sout)970 void Session::createJoinIndex(Impl::MappingInfo& joinTableMapping,
971 			      Impl::MappingInfo *mapping,
972 			      const std::string& joinId,
973 			      const std::string& foreignKeyName,
974 			      std::ostream *sout)
975 {
976   std::stringstream sql;
977 
978   sql << "create index \"" << joinTableMapping.tableName << "_"
979       << mapping->tableName;
980 
981   if (!joinId.empty())
982     sql << "_" << joinId;
983 
984   sql << "\" on \"" << Impl::quoteSchemaDot(joinTableMapping.tableName)
985       << "\" (";
986 
987   bool firstField = true;
988   for (unsigned int i = 0; i < joinTableMapping.fields.size(); ++i) {
989     const FieldInfo& f = joinTableMapping.fields[i];
990     if (f.foreignKeyName() == foreignKeyName) {
991       if (!firstField)
992 	sql << ", ";
993       firstField = false;
994 
995       sql << "\"" << f.name() << "\"";
996     }
997   }
998 
999   sql << ")";
1000 
1001   executeSql(sql, sout);
1002 }
1003 
1004 std::vector<Session::JoinId>
getJoinIds(Impl::MappingInfo * mapping,const std::string & joinId,bool literalJoinId)1005 Session::getJoinIds(Impl::MappingInfo *mapping, const std::string& joinId, bool literalJoinId)
1006 {
1007   std::vector<Session::JoinId> result;
1008 
1009   std::string foreignKeyName;
1010   if (joinId.empty())
1011     foreignKeyName = std::string(mapping->tableName);
1012   else
1013     foreignKeyName = joinId;
1014 
1015   if (mapping->surrogateIdFieldName) {
1016     std::string idName;
1017 
1018     if (literalJoinId)
1019       idName = joinId;
1020     else
1021       idName = foreignKeyName
1022 	+ "_" + mapping->surrogateIdFieldName;
1023 
1024     result.push_back
1025       (JoinId(idName, mapping->surrogateIdFieldName, longlongType_));
1026 
1027   } else {
1028     int nbNaturalIdFields = 0;
1029     for (unsigned i = 0; i < mapping->fields.size(); ++i) {
1030       const FieldInfo& field = mapping->fields[i];
1031 
1032       if (field.isNaturalIdField()) {
1033 	++nbNaturalIdFields;
1034 	std::string idName;
1035 	if (literalJoinId) {
1036 	  // NOTE: there should be only one natural id field in this case!
1037 	  idName = joinId;
1038 	} else {
1039 	  idName = foreignKeyName + "_" + field.name();
1040 	}
1041 	result.push_back(JoinId(idName, field.name(), field.sqlType()));
1042       }
1043     }
1044     if (literalJoinId && nbNaturalIdFields != 1) {
1045       throw Exception(std::string("The literal join id >") + joinId + " was used,"
1046                       " but there are " + std::to_string(nbNaturalIdFields) +
1047 		      " natural id fields. There may only be one natural id field.");
1048     }
1049   }
1050 
1051   return result;
1052 }
1053 
addJoinTableFields(Impl::MappingInfo & result,Impl::MappingInfo * mapping,const std::string & joinId,const std::string & keyName,int fkConstraints,bool literalJoinId)1054 void Session::addJoinTableFields(Impl::MappingInfo& result,
1055 				 Impl::MappingInfo *mapping,
1056 				 const std::string& joinId,
1057 				 const std::string& keyName,
1058 				 int fkConstraints,
1059 				 bool literalJoinId)
1060 {
1061   std::vector<JoinId> joinIds = getJoinIds(mapping, joinId, literalJoinId);
1062 
1063   for (unsigned i = 0; i < joinIds.size(); ++i)
1064     result.fields.push_back
1065       (FieldInfo(joinIds[i].joinIdName, &typeid(long long),
1066 		 joinIds[i].sqlType,
1067 		 mapping->tableName, keyName,
1068 		 FieldFlags::NaturalId | FieldFlags::ForeignKey,
1069 		 fkConstraints));
1070 }
1071 
dropTables()1072 void Session::dropTables()
1073 {
1074   initSchema();
1075 
1076   if (transaction_) {
1077     flush();
1078   }
1079 
1080   if (connectionPool_) {
1081     connectionPool_->prepareForDropTables();
1082     if (transaction_) {
1083       transaction_->connection_->prepareForDropTables();
1084     }
1085   } else if (connection_) {
1086     connection_->prepareForDropTables();
1087   } else if (transaction_) {
1088     transaction_->connection_->prepareForDropTables();
1089   }
1090 
1091   Transaction t(*this);
1092 
1093   flush();
1094 
1095   //remove constraints first.
1096   if (connection(false)->supportAlterTable()){
1097     for (ClassRegistry::iterator i = classRegistry_.begin();
1098          i != classRegistry_.end(); ++i){
1099       Impl::MappingInfo *mapping = i->second;
1100       //find the constraint.
1101       //ALTER TABLE products DROP CONSTRAINT some_name
1102       for (unsigned j = 0; j < mapping->fields.size(); ++j) {
1103         const FieldInfo& field = mapping->fields[j];
1104         if (field.isForeignKey()){
1105           std::stringstream sql;
1106 	  std::string table = Impl::quoteSchemaDot(mapping->tableName);
1107 
1108           sql << "alter table \"" << table << "\""
1109               << " drop "
1110               << connection(false)->alterTableConstraintString() << " "
1111               << constraintName(mapping->tableName, field.foreignKeyName());
1112 
1113           j = findLastForeignKeyField(mapping, field, j);
1114 
1115 	  executeSql(sql, nullptr);
1116         }
1117       }
1118     }
1119   }
1120 
1121   std::set<std::string> tablesDropped;
1122   for (ClassRegistry::iterator i = classRegistry_.begin();
1123        i != classRegistry_.end(); ++i)
1124     i->second->dropTable(*this, tablesDropped);
1125 
1126   t.commit();
1127 }
1128 
getMapping(const char * tableName)1129 Impl::MappingInfo *Session::getMapping(const char *tableName) const
1130 {
1131   TableRegistry::const_iterator i = tableRegistry_.find(tableName);
1132 
1133   if (i != tableRegistry_.end())
1134     return i->second;
1135   else
1136     return nullptr;
1137 }
1138 
needsFlush(MetaDboBase * obj)1139 void Session::needsFlush(MetaDboBase *obj)
1140 {
1141   typedef Impl::MetaDboBaseSet::nth_index<1>::type Set;
1142   Set& setIndex = dirtyObjects_->get<1>();
1143 
1144   std::pair<Set::iterator, bool> inserted = setIndex.insert(obj);
1145 
1146   if (inserted.second) {
1147     // was a new entry
1148     obj->incRef();
1149   }
1150 
1151   // If it's a delete, move it to the back
1152   //
1153   // In fact, this might be wrong: we need to consider dependencies
1154   // (constraints) that depend on this object: foreign keys generated
1155   // by 'belongsTo()' referencing this object: the objects that hold
1156   // these foreign keys may need to be updated (or deleted!) before
1157   // this object is deleted, one thus needs to take care of the order in which
1158   // objects are being deleted
1159   if (obj->isDeleted()) {
1160     // was an existing entry, move to back
1161     typedef Impl::MetaDboBaseSet::nth_index<0>::type List;
1162     List& listIndex = dirtyObjects_->get<0>();
1163 
1164     List::iterator i = dirtyObjects_->project<0>(inserted.first);
1165 
1166     listIndex.splice(listIndex.end(), listIndex, i);
1167   }
1168 }
1169 
flush()1170 void Session::flush()
1171 {
1172   for (unsigned i=0; i < objectsToAdd_.size(); i++)
1173     needsFlush(objectsToAdd_[i]);
1174 
1175   objectsToAdd_.clear();
1176 
1177   while (!dirtyObjects_->empty()) {
1178     Impl::MetaDboBaseSet::iterator i = dirtyObjects_->begin();
1179     MetaDboBase *dbo = *i;
1180     dbo->flush();
1181     dirtyObjects_->erase(i);
1182     dbo->decRef();
1183   }
1184 }
1185 
rereadAll(const char * tableName)1186 void Session::rereadAll(const char *tableName)
1187 {
1188   for (ClassRegistry::iterator i = classRegistry_.begin();
1189        i != classRegistry_.end(); ++i)
1190     if (!tableName || std::string(tableName) == i->second->tableName)
1191       i->second->rereadAll();
1192 }
1193 
discardUnflushed()1194 void Session::discardUnflushed()
1195 {
1196   objectsToAdd_.clear();
1197   rereadAll();
1198 }
1199 
statementId(const char * tableName,int statementIdx)1200 std::string Session::statementId(const char *tableName, int statementIdx)
1201 {
1202   return std::string(tableName) + ":" + std::to_string(statementIdx);
1203 }
1204 
getStatement(const std::string & id)1205 SqlStatement *Session::getStatement(const std::string& id)
1206 {
1207   return connection(true)->getStatement(id);
1208 }
1209 
getOrPrepareStatement(const std::string & sql)1210 SqlStatement *Session::getOrPrepareStatement(const std::string& sql)
1211 {
1212   SqlStatement *s = getStatement(sql);
1213 
1214   if (!s)
1215     s = prepareStatement(sql, sql);
1216 
1217   return s;
1218 }
1219 
getStatement(const char * tableName,int statementIdx)1220 SqlStatement *Session::getStatement(const char *tableName, int statementIdx)
1221 {
1222   std::string id = statementId(tableName, statementIdx);
1223   SqlStatement *result = getStatement(id);
1224 
1225   if (!result)
1226     result = prepareStatement(id, getStatementSql(tableName, statementIdx));
1227 
1228   return result;
1229 }
1230 
1231 const std::string&
getStatementSql(const char * tableName,int statementIdx)1232 Session::getStatementSql(const char *tableName, int statementIdx)
1233 {
1234   return getMapping(tableName)->statements[statementIdx];
1235 }
1236 
prepareStatement(const std::string & id,const std::string & sql)1237 SqlStatement *Session::prepareStatement(const std::string& id,
1238 					const std::string& sql)
1239 {
1240   SqlConnection *conn = connection(false);
1241   std::unique_ptr<SqlStatement> stmt = conn->prepareStatement(sql);
1242   SqlStatement *result = stmt.get();
1243   conn->saveStatement(id, std::move(stmt));
1244   result->use();
1245 
1246   return result;
1247 }
1248 
getFields(const char * tableName,std::vector<FieldInfo> & result)1249 void Session::getFields(const char *tableName,
1250 			std::vector<FieldInfo>& result)
1251 {
1252   initSchema();
1253 
1254   Impl::MappingInfo *mapping = getMapping(tableName);
1255   if (!mapping)
1256     throw Exception(std::string("Table ") + tableName + " was not mapped.");
1257 
1258   if (mapping->surrogateIdFieldName)
1259     result.push_back(FieldInfo(mapping->surrogateIdFieldName,
1260 			       &typeid(long long),
1261 			       longlongType_,
1262 			       FieldFlags::SurrogateId |
1263 			       FieldFlags::NeedsQuotes));
1264 
1265   if (mapping->versionFieldName)
1266     result.push_back(FieldInfo(mapping->versionFieldName, &typeid(int),
1267 			       intType_,
1268 			       FieldFlags::Version | FieldFlags::NeedsQuotes));
1269 
1270   result.insert(result.end(), mapping->fields.begin(), mapping->fields.end());
1271 }
1272 
createDbo(Impl::MappingInfo * mapping)1273 MetaDboBase *Session::createDbo(Impl::MappingInfo *mapping)
1274 {
1275   return mapping->create(*this);
1276 }
1277 
load(MetaDboBase * dbo)1278 void Session::load(MetaDboBase *dbo)
1279 {
1280   Impl::MappingInfo *mapping = dbo->getMapping();
1281   mapping->load(*this, dbo);
1282 }
1283 
1284   }
1285 }
1286