1 #ifndef ONCE_FPARSERAD_H_
2 #define ONCE_FPARSERAD_H_
3
4 #include "fparser.hh"
5 #include <exception>
6 #include <iostream>
7 #include <fstream>
8
9 template<typename Value_t>
10 class ADImplementation;
11
12 template<typename Value_t>
13 class FunctionParserADBase : public FunctionParserBase<Value_t>
14 {
15 public:
16 FunctionParserADBase();
17 FunctionParserADBase(const FunctionParserADBase& cpy);
18 virtual ~FunctionParserADBase();
19
20 /**
21 * This class manages its own memory, so the compiler-generated copy
22 * assignment, move assignment, and move constructor implementations
23 * are not safe to use. We therefore explicitly delete them so they
24 * can't be called accidentally.
25 */
26 FunctionParserADBase (FunctionParserADBase &&) = delete;
27 FunctionParserADBase & operator= (const FunctionParserADBase &) = delete;
28 FunctionParserADBase & operator= (FunctionParserADBase &&) = delete;
29
30 /**
31 * auto-differentiate for var
32 */
33 int AutoDiff(const std::string & var_name);
34
35 /**
36 * add another variable
37 */
38 bool AddVariable(const std::string & var_name);
39
40 /**
41 * check if the function is equal to 0
42 * This is a common case for vanishing derivatives. This relies on the
43 * function to be optimized.
44 */
45 bool isZero();
46
47 /**
48 * check if the function's byte code is empty.
49 */
isEmpty()50 bool isEmpty() { return this->mData->mByteCode.empty(); }
51
52 /**
53 * set the bytecode of this function to return constant zero.
54 * this provides a well defined state in case AutoDiff fails
55 */
56 void setZero();
57
58 // feature flags for this parser
59 enum ADFlags {
60 /**
61 * In certain applications derivatives are built proactively and may never be used.
62 * We silence all AutoDiff exceptions in that case to avoid confusing the user.
63 */
64 ADSilenceErrors = 1,
65 /**
66 * Immediately apply the optimizer grammars to the derivative tree structure.
67 * This saves a round trip to byte code.
68 */
69 ADAutoOptimize = 2,
70 /**
71 * Use cached JIT compiled function files, This bypasses the compilation stage
72 * for all further runs, after the JIT compilation ran successfully at least once.
73 */
74 ADJITCache = 4,
75 /**
76 * Use cached bytecode for the derivatives. This bypasses the automatic differentiation,
77 * (optimization), and byte code synthesis.
78 */
79 ADCacheDerivatives = 8
80 };
81
82 /**
83 * Set the feature flags for this parser (this way gives us better control over default values)
84 */
SetADFlags(int flags,bool turnon=true)85 void SetADFlags(int flags, bool turnon = true) {
86 if (turnon)
87 mADFlags |= flags;
88 else
89 mADFlags &= ~flags;
90 }
UnsetADFlags(int flags)91 void UnsetADFlags(int flags) { mADFlags &= ~flags; }
ClearADFlags()92 void ClearADFlags() { mADFlags = 0; }
93
94 /**
95 * compile the current function, or load a previously compiled copy.
96 * Warning: When re-using an FParser function object by parsing a new expression
97 * the previously JIT compiled function will continue to be Evaled until the
98 * JITCompile method is called again.
99 */
100 bool JITCompile();
101
102 /**
103 * wrap Optimize of the parent class to check for a JIT compiled version and redo
104 * the compilation after Optimization
105 */
106 void Optimize();
107
108 /**
109 * write the full state of the current FParser object to a stream
110 */
111 void Serialize(std::ostream &);
112
113 /**
114 * restore the full state of the current FParser object from a stream
115 */
116 void Unserialize(std::istream &);
117
118 #if LIBMESH_HAVE_FPARSER_JIT
119 /**
120 * Overwrite the Exec function with one that tests for a JIT compiled version
121 * and uses that if it exists
122 */
123 Value_t Eval(const Value_t* Vars);
124 #endif
125
126 /**
127 * look up the opcode number for a given variable name
128 * throws UnknownVariableException if the variable is not found
129 */
130 unsigned int LookUpVarOpcode(const std::string & var_name);
131
132 /**
133 * register a dependency between variables so that da/db = c
134 */
135 void RegisterDerivative(const std::string & a, const std::string & b, const std::string & c);
136
137 protected:
138
139 #if LIBMESH_HAVE_FPARSER_JIT
140 /// return a SHA1 hash for the current bytecode and value type name
141 std::string JITCodeHash(const std::string & value_type_name);
142
143 /// write generated C++ code to stream
144 bool JITCodeGen(std::ostream & ccout, const std::string & fname, const std::string & Value_t_name);
145
146 /// helper function to perform the JIT compilation (needs the Value_t typename as a string)
147 bool JITCompileHelper(const std::string & Value_t_name,
148 const std::string & extra_options = "",
149 const std::string & extra_headers = "");
150 #endif // LIBMESH_HAVE_FPARSER_JIT
151
152 /// function pointer type alias. This permits a Real Value_t function to be compiled
153 /// to support dual numbers
154 template <typename ActualValue_t>
155 using CompiledFunctionPtr = void (*)(ActualValue_t *, const ActualValue_t *,
156 const Value_t *, const Value_t);
157
158 /// update pointer to immediate data
159 void updatePImmed();
160
161 /// clear the runtime evaluation error flag
clearEvalError()162 void clearEvalError() { this->mData->mEvalErrorType = 0; }
163
164 /// JIT function pointer
165 void *compiledFunction;
166
167 /// pointer to the mImmed values (or NULL if the mImmed vector is empty)
168 Value_t * pImmed;
169
170 // user function plog
171 static Value_t fp_plog(const Value_t * params);
172
173 // user function erf
174 static Value_t fp_erf(const Value_t * params);
175
176 // function ID for the plog function
177 unsigned int mFPlog;
178
179 // function ID for the erf function
180 unsigned int mFErf;
181
182 // flags that control cache bahavior, optimization, and error reporting
183 int mADFlags;
184
185 // registered derivative table, and entry structure
186 struct VariableDerivative {
187 unsigned int var, dependence, derivative;
188 };
189 std::vector<VariableDerivative> mRegisteredDerivatives;
190
191 // private implementaion of the automatic differentiation algorithm
192 ADImplementation<Value_t> * ad;
193
194 // the firewalled implementation class of the AD algorithm has full access to the FParser object
195 friend class ADImplementation<Value_t>;
196
197 // Exceptions
198 class UnknownVariable : public std::exception {
what() const199 virtual const char* what() const throw() override { return "Unknown variable"; }
200 } UnknownVariableException;
201 class UnknownSerializationVersion : public std::exception {
what() const202 virtual const char* what() const throw() override { return "Unknown serialization file version"; }
203 } UnknownSerializationVersionException;
204 };
205
206 #ifdef LIBMESH_HAVE_FPARSER_JIT
207
/// Forward declare SHA1 hash object
class SHA1;

/// Namespacing the utility classes (rather than nesting them in a templated class)
namespace FParserJIT
{
/**
 * Simplified C++ interface to lib SHA1.
 *
 * Hash holds a raw pointer to a SHA1 object, presumably allocated in Hash() and
 * released in ~Hash() (both defined out of line). A compiler-generated copy
 * would make two Hash objects share that pointer, so copying is disabled.
 */
class Hash
{
public:
  Hash();
  ~Hash();

  Hash(const Hash &) = delete;
  Hash & operator=(const Hash &) = delete;

  /// feed the raw bytes of a contiguous container (anything with data() and size())
  template <typename T>
  void addData(const T & v);

  /// get the hash of the data added so far as a string
  /// (NOTE(review): output format is defined out of line - confirm)
  std::string get();

protected:
  /// the actual lib SHA1 call is in the helper so that we don't need to make lib/sha1.h available
  void addDataHelper(const char * start, std::size_t size);

  SHA1 * _sha1;
};

/**
 * Hash the contents of v byte-wise. The byte count is the element count times
 * the element size; empty input is skipped entirely.
 */
template <typename T>
void Hash::addData(const T & v)
{
  auto start = v.data();
  std::size_t size = v.size() * sizeof(*start);
  if (size > 0)
    addDataHelper(reinterpret_cast<const char *>(start), size);
}

/// Handle compilation, caching, and temporary files
class Compiler
{
public:
  Compiler(const std::string & master_hash = "");
  ~Compiler();

  // Not copyable: the std::ofstream member already implies this; spelled out for clarity.
  Compiler(const Compiler &) = delete;
  Compiler & operator=(const Compiler &) = delete;

  /// stream to which the generated C++ source is written
  std::ostream & source();

  bool probeCache();
  bool run(const std::string & compiler_options = "");
  void * getFunction(const std::string & fname);

protected:
  std::ofstream _ccout;
  void * _lib;
  const std::string _jitdir;
  std::string _ccname;
  std::string _objectname;
  std::string _object_so;
  bool _success;

  const std::string _master_hash;
  const bool _use_cache;
};
} // namespace FParserJIT
267
268 #endif // LIBMESH_HAVE_FPARSER_JIT
269
270 class FunctionParserAD: public FunctionParserADBase<double> {};
271 class FunctionParserAD_f: public FunctionParserADBase<float> {};
272
273 #endif //ONCE_FPARSERAD_H_
274