1 //
2 //  Copyright (C) 2020 Shrey Aryan
3 //
4 //   @@ All Rights Reserved @@
5 //  This file is part of the RDKit.
6 //  The contents are covered by the terms of the BSD license
7 //  which is included in the file license.txt, found at the root
8 //  of the RDKit source tree.
9 //
10 #ifndef GENERAL_FILE_READER_H
11 #define GENERAL_FILE_READER_H
12 #include <RDGeneral/BadFileException.h>
13 #include <RDStreams/streams.h>
14 
15 #include <boost/algorithm/string.hpp>
16 #include <iostream>
17 #include <memory>
18 #include <string>
19 #include <vector>
20 
21 #include "MolSupplier.h"
22 #include "MultithreadedSDMolSupplier.h"
23 #include "MultithreadedSmilesMolSupplier.h"
24 
25 namespace RDKit {
26 namespace GeneralMolSupplier {
//! Bundle of options that getSupplier() forwards to the supplier it creates.
/*!
    Each supplier only receives the subset of fields that its constructor
    call in getSupplier() names; the remaining fields are ignored for that
    file format.
*/
struct SupplierOptions {
  // --- general options -----------------------------------------------------
  //! not read by getSupplier() in this header, which always passes `true`
  //! for stream ownership
  bool takeOwnership{true};
  //! forwarded to every supplier created by getSupplier()
  bool sanitize{true};
  //! forwarded to the SD and Mae suppliers
  bool removeHs{true};
  //! forwarded to the SD suppliers
  bool strictParsing{true};

  // --- options forwarded to the SMILES suppliers (smi, csv, txt, tsv) -------
  std::string delimiter{"\t"};
  int smilesColumn{0};
  int nameColumn{1};
  bool titleLine{true};

  // --- options forwarded to the TDT supplier -------------------------------
  std::string nameRecord{""};
  int confId2D{-1};
  int confId3D{0};

  //! when greater than zero (and RDK_THREADSAFE_SSS is defined) getSupplier()
  //! creates the multithreaded SD / SMILES supplier and passes this value on
  unsigned int numWriterThreads{0};
};
//! file format suffixes (given without the leading dot) that are currently
//! recognised; determineFormat() tests a path against them in this order
const std::vector<std::string> supportedFileFormats = {
    "sdf",
    "mae",
    "maegz",
    "sdfgz",
    "smi",
    "csv",
    "txt",
    "tsv",
    "tdt",
};
//! compression format suffixes (given without the leading dot) that are
//! currently recognised
const std::vector<std::string> supportedCompressionFormats = {"gz"};
49 
50 //! given file path determines the file and compression format
51 //! returns true on success, otherwise false
52 //! Note: Error handeling is done in the getSupplier method
53 
determineFormat(const std::string path,std::string & fileFormat,std::string & compressionFormat)54 void determineFormat(const std::string path, std::string& fileFormat,
55                      std::string& compressionFormat) {
56   //! filename without compression format
57   std::string basename;
58   //! Special case maegz.
59   //! NOTE: also supporting case-insensitive filesystems
60   if (boost::algorithm::iends_with(path, ".maegz")) {
61     fileFormat = "mae";
62     compressionFormat = "gz";
63     return;
64   } else if (boost::algorithm::iends_with(path, ".sdfgz")) {
65     fileFormat = "sdf";
66     compressionFormat = "gz";
67     return;
68   } else if (boost::algorithm::iends_with(path, ".gz")) {
69     compressionFormat = "gz";
70     basename = path.substr(0, path.size() - 3);
71   } else if (boost::algorithm::iends_with(path, ".zst") ||
72              boost::algorithm::iends_with(path, ".bz2") ||
73              boost::algorithm::iends_with(path, ".7z")) {
74     throw BadFileException(
75         "Unsupported compression extension (.zst, .bz2, .7z) given path: " +
76         path);
77   } else {
78     basename = path;
79     compressionFormat = "";
80   }
81   for (auto const& suffix : supportedFileFormats) {
82     if (boost::algorithm::iends_with(basename, "." + suffix)) {
83       fileFormat = suffix;
84       return;
85     }
86   }
87   throw BadFileException(
88       "Unsupported structure or compression extension given path: " + path);
89 }
90 
91 //! returns a new MolSupplier object based on the file name instantiated
92 //! with the relevant options provided in the SupplierOptions struct
93 /*!
94     <b>Note:</b>
95       - the caller owns the memory and therefore the pointer must be deleted
96 */
97 
getSupplier(const std::string & path,const struct SupplierOptions & opt)98 std::unique_ptr<MolSupplier> getSupplier(const std::string& path,
99                                          const struct SupplierOptions& opt) {
100   std::string fileFormat = "";
101   std::string compressionFormat = "";
102   //! get the file and compression format form the path
103   determineFormat(path, fileFormat, compressionFormat);
104 
105   std::istream* strm;
106   if (compressionFormat.empty()) {
107     strm = new std::ifstream(path.c_str(), std::ios::in | std::ios::binary);
108   } else {
109     strm = new gzstream(path);
110   }
111 
112   //! Dispatch to the appropriate supplier
113   if (fileFormat == "sdf") {
114 #ifdef RDK_THREADSAFE_SSS
115     if (opt.numWriterThreads > 0) {
116       MultithreadedSDMolSupplier* sdsup = new MultithreadedSDMolSupplier(
117           strm, true, opt.sanitize, opt.removeHs, opt.strictParsing,
118           opt.numWriterThreads);
119       std::unique_ptr<MolSupplier> p(sdsup);
120       return p;
121     }
122 #endif
123     ForwardSDMolSupplier* sdsup = new ForwardSDMolSupplier(
124         strm, true, opt.sanitize, opt.removeHs, opt.strictParsing);
125     std::unique_ptr<MolSupplier> p(sdsup);
126     return p;
127   }
128 
129   else if (fileFormat == "smi" || fileFormat == "csv" || fileFormat == "txt" ||
130            fileFormat == "tsv") {
131 #ifdef RDK_THREADSAFE_SSS
132     if (opt.numWriterThreads > 0) {
133       MultithreadedSmilesMolSupplier* smsup =
134           new MultithreadedSmilesMolSupplier(
135               strm, true, opt.delimiter, opt.smilesColumn, opt.nameColumn,
136               opt.titleLine, opt.sanitize, opt.numWriterThreads);
137       std::unique_ptr<MolSupplier> p(smsup);
138       return p;
139     }
140 #endif
141     SmilesMolSupplier* smsup =
142         new SmilesMolSupplier(strm, true, opt.delimiter, opt.smilesColumn,
143                               opt.nameColumn, opt.titleLine, opt.sanitize);
144     std::unique_ptr<MolSupplier> p(smsup);
145     return p;
146   }
147 #ifdef RDK_BUILD_MAEPARSER_SUPPORT
148   else if (fileFormat == "mae") {
149     MaeMolSupplier* maesup =
150         new MaeMolSupplier(strm, true, opt.sanitize, opt.removeHs);
151     std::unique_ptr<MolSupplier> p(maesup);
152     return p;
153   }
154 #endif
155   else if (fileFormat == "tdt") {
156     TDTMolSupplier* tdtsup = new TDTMolSupplier(
157         strm, true, opt.nameRecord, opt.confId2D, opt.confId3D, opt.sanitize);
158     std::unique_ptr<MolSupplier> p(tdtsup);
159     return p;
160   }
161   throw BadFileException("Unsupported fileFormat: " + fileFormat);
162 }
163 
164 }  // namespace GeneralMolSupplier
165 }  // namespace RDKit
166 #endif
167