1 /** @file function.cpp
2 *
3 * Implementation of class of symbolic functions. */
4
5 /*
6 * GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7 *
8 * This program is free software; you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation; either version 2 of the License, or
11 * (at your option) any later version.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21 */
22
23 #define register
24 #define PY_SSIZE_T_CLEAN
25 #include <Python.h>
26 #include "py_funcs.h"
27 #include "function.h"
28 #include "operators.h"
29 #include "fderivative.h"
30 #include "ex.h"
31 #include "lst.h"
32 #include "print.h"
33 #include "power.h"
34 #include "relational.h"
35 #include "archive.h"
36 #include "inifcns.h"
37 #include "tostring.h"
38 #include "utils.h"
39 #include "remember.h"
40 #include "symbol.h"
41 #include "cmatcher.h"
42 #include "wildcard.h"
43 #include "expairseq.h"
44
45 #include <iostream>
46 #include <string>
47 #include <stdexcept>
48 #include <list>
49 #include <limits>
50 #ifdef DO_GINAC_ASSERT
51 # include <typeinfo>
52 #endif
53
54 namespace GiNaC {
55
56 //////////
57 // helper class function_options
58 //////////
59
function_options()60 function_options::function_options()
61 {
62 initialize();
63 }
64
function_options(std::string const & n,std::string const & tn)65 function_options::function_options(std::string const & n, std::string const & tn)
66 {
67 initialize();
68 set_name(n, tn);
69 }
70
function_options(std::string const & n,unsigned np)71 function_options::function_options(std::string const & n, unsigned np)
72 {
73 static std::string empty;
74 initialize();
75 set_name(n, empty);
76 nparams = np;
77 }
78
~function_options()79 function_options::~function_options()
80 {
81 // nothing to clean up at the moment
82 }
83
initialize()84 void function_options::initialize()
85 {
86 static std::string s1("unnamed_function"), s2("\\mbox{unnamed}");
87 set_name(s1, s2);
88 nparams = 0;
89 eval_f = real_part_f = imag_part_f = conjugate_f = derivative_f
90 = pynac_eval_f = expl_derivative_f = power_f = series_f
91 = subs_f = nullptr;
92 evalf_f = nullptr;
93 evalf_params_first = true;
94 apply_chain_rule = true;
95 use_return_type = false;
96 eval_use_exvector_args = false;
97 evalf_use_exvector_args = false;
98 conjugate_use_exvector_args = false;
99 real_part_use_exvector_args = false;
100 imag_part_use_exvector_args = false;
101 derivative_use_exvector_args = false;
102 expl_derivative_use_exvector_args = false;
103 power_use_exvector_args = false;
104 series_use_exvector_args = false;
105 print_use_exvector_args = false;
106 use_remember = false;
107 python_func = 0;
108 functions_with_same_name = 1;
109 symtree = 0;
110 }
111
set_name(std::string const & n,std::string const & tn)112 function_options & function_options::set_name(std::string const & n,
113 std::string const & tn)
114 {
115 name.assign(n);
116 if (tn.empty()) {
117 TeX_name.assign("{\\rm ");
118 TeX_name += n;
119 TeX_name.append("}");
120 }
121 else
122 TeX_name.assign(tn);
123 return *this;
124 }
125
latex_name(std::string const & tn)126 function_options & function_options::latex_name(std::string const & tn)
127 {
128 TeX_name = tn;
129 return *this;
130 }
131
132 // the following lines have been generated for max. 14 parameters
eval_func(eval_funcp_1 e)133 function_options & function_options::eval_func(eval_funcp_1 e)
134 {
135 test_and_set_nparams(1);
136 eval_f = eval_funcp(e);
137 pynac_eval_f = eval_f;
138 return *this;
139 }
eval_func(eval_funcp_2 e)140 function_options & function_options::eval_func(eval_funcp_2 e)
141 {
142 test_and_set_nparams(2);
143 eval_f = eval_funcp(e);
144 pynac_eval_f = eval_f;
145 return *this;
146 }
eval_func(eval_funcp_3 e)147 function_options & function_options::eval_func(eval_funcp_3 e)
148 {
149 test_and_set_nparams(3);
150 eval_f = eval_funcp(e);
151 pynac_eval_f = eval_f;
152 return *this;
153 }
eval_func(eval_funcp_6 e)154 function_options & function_options::eval_func(eval_funcp_6 e)
155 {
156 test_and_set_nparams(6);
157 eval_f = eval_funcp(e);
158 pynac_eval_f = eval_f;
159 return *this;
160 }
161
evalf_func(evalf_funcp_1 ef)162 function_options & function_options::evalf_func(evalf_funcp_1 ef)
163 {
164 test_and_set_nparams(1);
165 evalf_f = evalf_funcp(ef);
166 return *this;
167 }
evalf_func(evalf_funcp_2 ef)168 function_options & function_options::evalf_func(evalf_funcp_2 ef)
169 {
170 test_and_set_nparams(2);
171 evalf_f = evalf_funcp(ef);
172 return *this;
173 }
evalf_func(evalf_funcp_3 ef)174 function_options & function_options::evalf_func(evalf_funcp_3 ef)
175 {
176 test_and_set_nparams(3);
177 evalf_f = evalf_funcp(ef);
178 return *this;
179 }
evalf_func(evalf_funcp_6 ef)180 function_options & function_options::evalf_func(evalf_funcp_6 ef)
181 {
182 test_and_set_nparams(6);
183 evalf_f = evalf_funcp(ef);
184 return *this;
185 }
186
conjugate_func(conjugate_funcp_1 c)187 function_options & function_options::conjugate_func(conjugate_funcp_1 c)
188 {
189 test_and_set_nparams(1);
190 conjugate_f = conjugate_funcp(c);
191 return *this;
192 }
conjugate_func(conjugate_funcp_2 c)193 function_options & function_options::conjugate_func(conjugate_funcp_2 c)
194 {
195 test_and_set_nparams(2);
196 conjugate_f = conjugate_funcp(c);
197 return *this;
198 }
conjugate_func(conjugate_funcp_3 c)199 function_options & function_options::conjugate_func(conjugate_funcp_3 c)
200 {
201 test_and_set_nparams(3);
202 conjugate_f = conjugate_funcp(c);
203 return *this;
204 }
205
real_part_func(real_part_funcp_1 c)206 function_options & function_options::real_part_func(real_part_funcp_1 c)
207 {
208 test_and_set_nparams(1);
209 real_part_f = real_part_funcp(c);
210 return *this;
211 }
real_part_func(real_part_funcp_2 c)212 function_options & function_options::real_part_func(real_part_funcp_2 c)
213 {
214 test_and_set_nparams(2);
215 real_part_f = real_part_funcp(c);
216 return *this;
217 }
real_part_func(real_part_funcp_3 c)218 function_options & function_options::real_part_func(real_part_funcp_3 c)
219 {
220 test_and_set_nparams(3);
221 real_part_f = real_part_funcp(c);
222 return *this;
223 }
224
imag_part_func(imag_part_funcp_1 c)225 function_options & function_options::imag_part_func(imag_part_funcp_1 c)
226 {
227 test_and_set_nparams(1);
228 imag_part_f = imag_part_funcp(c);
229 return *this;
230 }
imag_part_func(imag_part_funcp_2 c)231 function_options & function_options::imag_part_func(imag_part_funcp_2 c)
232 {
233 test_and_set_nparams(2);
234 imag_part_f = imag_part_funcp(c);
235 return *this;
236 }
imag_part_func(imag_part_funcp_3 c)237 function_options & function_options::imag_part_func(imag_part_funcp_3 c)
238 {
239 test_and_set_nparams(3);
240 imag_part_f = imag_part_funcp(c);
241 return *this;
242 }
243
derivative_func(derivative_funcp_1 d)244 function_options & function_options::derivative_func(derivative_funcp_1 d)
245 {
246 test_and_set_nparams(1);
247 derivative_f = derivative_funcp(d);
248 return *this;
249 }
derivative_func(derivative_funcp_2 d)250 function_options & function_options::derivative_func(derivative_funcp_2 d)
251 {
252 test_and_set_nparams(2);
253 derivative_f = derivative_funcp(d);
254 return *this;
255 }
derivative_func(derivative_funcp_3 d)256 function_options & function_options::derivative_func(derivative_funcp_3 d)
257 {
258 test_and_set_nparams(3);
259 derivative_f = derivative_funcp(d);
260 return *this;
261 }
derivative_func(derivative_funcp_6 d)262 function_options & function_options::derivative_func(derivative_funcp_6 d)
263 {
264 test_and_set_nparams(6);
265 derivative_f = derivative_funcp(d);
266 return *this;
267 }
268
expl_derivative_func(expl_derivative_funcp_1 d)269 function_options & function_options::expl_derivative_func(expl_derivative_funcp_1 d)
270 {
271 test_and_set_nparams(1);
272 expl_derivative_f = expl_derivative_funcp(d);
273 return *this;
274 }
expl_derivative_func(expl_derivative_funcp_2 d)275 function_options & function_options::expl_derivative_func(expl_derivative_funcp_2 d)
276 {
277 test_and_set_nparams(2);
278 expl_derivative_f = expl_derivative_funcp(d);
279 return *this;
280 }
expl_derivative_func(expl_derivative_funcp_3 d)281 function_options & function_options::expl_derivative_func(expl_derivative_funcp_3 d)
282 {
283 test_and_set_nparams(3);
284 expl_derivative_f = expl_derivative_funcp(d);
285 return *this;
286 }
287
power_func(power_funcp_1 d)288 function_options & function_options::power_func(power_funcp_1 d)
289 {
290 test_and_set_nparams(1);
291 power_f = power_funcp(d);
292 return *this;
293 }
power_func(power_funcp_2 d)294 function_options & function_options::power_func(power_funcp_2 d)
295 {
296 test_and_set_nparams(2);
297 power_f = power_funcp(d);
298 return *this;
299 }
power_func(power_funcp_3 d)300 function_options & function_options::power_func(power_funcp_3 d)
301 {
302 test_and_set_nparams(3);
303 power_f = power_funcp(d);
304 return *this;
305 }
306
series_func(series_funcp_1 s)307 function_options & function_options::series_func(series_funcp_1 s)
308 {
309 test_and_set_nparams(1);
310 series_f = series_funcp(s);
311 return *this;
312 }
series_func(series_funcp_2 s)313 function_options & function_options::series_func(series_funcp_2 s)
314 {
315 test_and_set_nparams(2);
316 series_f = series_funcp(s);
317 return *this;
318 }
series_func(series_funcp_3 s)319 function_options & function_options::series_func(series_funcp_3 s)
320 {
321 test_and_set_nparams(3);
322 series_f = series_funcp(s);
323 return *this;
324 }
325
subs_func(subs_funcp_1 s)326 function_options & function_options::subs_func(subs_funcp_1 s)
327 {
328 test_and_set_nparams(1);
329 subs_f = subs_funcp(s);
330 return *this;
331 }
subs_func(subs_funcp_2 s)332 function_options & function_options::subs_func(subs_funcp_2 s)
333 {
334 test_and_set_nparams(2);
335 subs_f = subs_funcp(s);
336 return *this;
337 }
subs_func(subs_funcp_3 s)338 function_options & function_options::subs_func(subs_funcp_3 s)
339 {
340 test_and_set_nparams(3);
341 subs_f = subs_funcp(s);
342 return *this;
343 }
344
345 // end of generated lines
346
eval_func(eval_funcp_exvector e)347 function_options& function_options::eval_func(eval_funcp_exvector e)
348 {
349 eval_use_exvector_args = true;
350 eval_f = eval_funcp(e);
351 pynac_eval_f = eval_f;
352 return *this;
353 }
evalf_func(evalf_funcp_exvector ef)354 function_options& function_options::evalf_func(evalf_funcp_exvector ef)
355 {
356 evalf_use_exvector_args = true;
357 evalf_f = evalf_funcp(ef);
358 return *this;
359 }
conjugate_func(conjugate_funcp_exvector c)360 function_options& function_options::conjugate_func(conjugate_funcp_exvector c)
361 {
362 conjugate_use_exvector_args = true;
363 conjugate_f = conjugate_funcp(c);
364 return *this;
365 }
real_part_func(real_part_funcp_exvector c)366 function_options& function_options::real_part_func(real_part_funcp_exvector c)
367 {
368 real_part_use_exvector_args = true;
369 real_part_f = real_part_funcp(c);
370 return *this;
371 }
imag_part_func(imag_part_funcp_exvector c)372 function_options& function_options::imag_part_func(imag_part_funcp_exvector c)
373 {
374 imag_part_use_exvector_args = true;
375 imag_part_f = imag_part_funcp(c);
376 return *this;
377 }
378
derivative_func(derivative_funcp_exvector d)379 function_options& function_options::derivative_func(derivative_funcp_exvector d)
380 {
381 derivative_use_exvector_args = true;
382 derivative_f = derivative_funcp(d);
383 return *this;
384 }
power_func(power_funcp_exvector d)385 function_options& function_options::power_func(power_funcp_exvector d)
386 {
387 power_use_exvector_args = true;
388 power_f = power_funcp(d);
389 return *this;
390 }
series_func(series_funcp_exvector s)391 function_options& function_options::series_func(series_funcp_exvector s)
392 {
393 series_use_exvector_args = true;
394 series_f = series_funcp(s);
395 return *this;
396 }
397
derivative_func(derivative_funcp_exvector_symbol d)398 function_options& function_options::derivative_func(
399 derivative_funcp_exvector_symbol d)
400 {
401 derivative_use_exvector_args = true;
402 derivative_f = derivative_funcp(d);
403 return *this;
404 }
405
eval_func(PyObject * e)406 function_options& function_options::eval_func(PyObject* e)
407 {
408 python_func |= eval_python_f;
409 eval_f = eval_funcp(e);
410 return *this;
411 }
evalf_func(PyObject * ef)412 function_options& function_options::evalf_func(PyObject* ef)
413 {
414 python_func |= evalf_python_f;
415 evalf_f = evalf_funcp(ef);
416 return *this;
417 }
conjugate_func(PyObject * c)418 function_options& function_options::conjugate_func(PyObject* c)
419 {
420 python_func |= conjugate_python_f;
421 conjugate_f = conjugate_funcp(c);
422 return *this;
423 }
real_part_func(PyObject * c)424 function_options& function_options::real_part_func(PyObject* c)
425 {
426 python_func |= real_part_python_f;
427 real_part_f = real_part_funcp(c);
428 return *this;
429 }
imag_part_func(PyObject * c)430 function_options& function_options::imag_part_func(PyObject* c)
431 {
432 python_func |= imag_part_python_f;
433 imag_part_f = imag_part_funcp(c);
434 return *this;
435 }
436
derivative_func(PyObject * d)437 function_options& function_options::derivative_func(PyObject* d)
438 {
439 python_func |= derivative_python_f;
440 derivative_f = derivative_funcp(d);
441 return *this;
442 }
power_func(PyObject * d)443 function_options& function_options::power_func(PyObject* d)
444 {
445 python_func |= power_python_f;
446 power_f = power_funcp(d);
447 return *this;
448 }
series_func(PyObject * s)449 function_options& function_options::series_func(PyObject* s)
450 {
451 python_func |= series_python_f;
452 series_f = series_funcp(s);
453 return *this;
454 }
subs_func(PyObject * e)455 function_options& function_options::subs_func(PyObject* e)
456 {
457 python_func |= subs_python_f;
458 subs_f = subs_funcp(e);
459 return *this;
460 }
461
set_return_type(unsigned rt,tinfo_t rtt)462 function_options & function_options::set_return_type(unsigned rt, tinfo_t rtt)
463 {
464 use_return_type = true;
465 return_type = rt;
466 return_type_tinfo = rtt;
467 return *this;
468 }
469
do_not_evalf_params()470 function_options & function_options::do_not_evalf_params()
471 {
472 evalf_params_first = false;
473 return *this;
474 }
475
do_not_apply_chain_rule()476 function_options & function_options::do_not_apply_chain_rule()
477 {
478 apply_chain_rule = false;
479 return *this;
480 }
481
remember(unsigned size,unsigned assoc_size,unsigned strategy)482 function_options & function_options::remember(unsigned size,
483 unsigned assoc_size,
484 unsigned strategy)
485 {
486 use_remember = true;
487 remember_size = size;
488 remember_assoc_size = assoc_size;
489 remember_strategy = strategy;
490 return *this;
491 }
492
overloaded(unsigned o)493 function_options & function_options::overloaded(unsigned o)
494 {
495 functions_with_same_name = o;
496 return *this;
497 }
498
test_and_set_nparams(unsigned n)499 void function_options::test_and_set_nparams(unsigned n)
500 {
501 if (nparams==0) {
502 nparams = n;
503 } else if (nparams!=n) {
504 // we do not throw an exception here because this code is
505 // usually executed before main(), so the exception could not
506 // be caught anyhow
507 std::cerr << "WARNING: " << name << "(): number of parameters ("
508 << n << ") differs from number set before ("
509 << nparams << ")" << std::endl;
510 }
511 }
512
set_print_func(unsigned id,print_funcp f)513 void function_options::set_print_func(unsigned id, print_funcp f)
514 {
515 if (id >= print_dispatch_table.size())
516 print_dispatch_table.resize(id + 1);
517 print_dispatch_table[id] = f;
518 }
519
set_print_latex_func(PyObject * f)520 void function_options::set_print_latex_func(PyObject* f)
521 {
522 unsigned id = print_latex::get_class_info_static().options.get_id();
523 if (id >= print_dispatch_table.size())
524 print_dispatch_table.resize(id + 1);
525 print_dispatch_table[id] = print_funcp(f);
526 }
527
set_print_dflt_func(PyObject * f)528 void function_options::set_print_dflt_func(PyObject* f)
529 {
530 unsigned id = print_dflt::get_class_info_static().options.get_id();
531 if (id >= print_dispatch_table.size())
532 print_dispatch_table.resize(id + 1);
533 print_dispatch_table[id] = print_funcp(f);
534 }
535
536 /** This can be used as a hook for external applications. */
537 unsigned function::current_serial = 0;
538
539
540 registered_class_info function::reg_info = \
541 registered_class_info(registered_class_options("function",
542 "exprseq",
543 &function::tinfo_static));
544
545 const tinfo_static_t function::tinfo_static = {};
546
547 //////////
548 // default constructor
549 //////////
550
551 // public
552
function()553 function::function() : serial(0)
554 {
555 tinfo_key = &function::tinfo_static;
556 }
557
558 //////////
559 // other constructors
560 //////////
561
562 // public
563
function(unsigned ser)564 function::function(unsigned ser) : serial(ser)
565 {
566 tinfo_key = &function::tinfo_static;
567 }
568
569 // the following lines have been generated for max. 14 parameters
function(unsigned ser,const ex & param1)570 function::function(unsigned ser, const ex & param1)
571 : exprseq(param1), serial(ser)
572 {
573 tinfo_key = &function::tinfo_static;
574 }
function(unsigned ser,const ex & param1,const ex & param2)575 function::function(unsigned ser, const ex & param1, const ex & param2)
576 : exprseq(param1, param2), serial(ser)
577 {
578 tinfo_key = &function::tinfo_static;
579 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3)580 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3)
581 : exprseq(param1, param2, param3), serial(ser)
582 {
583 tinfo_key = &function::tinfo_static;
584 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3,const ex & param4)585 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3, const ex & param4)
586 : exprseq(param1, param2, param3, param4), serial(ser)
587 {
588 tinfo_key = &function::tinfo_static;
589 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3,const ex & param4,const ex & param5)590 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3, const ex & param4, const ex & param5)
591 : exprseq(param1, param2, param3, param4, param5), serial(ser)
592 {
593 tinfo_key = &function::tinfo_static;
594 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3,const ex & param4,const ex & param5,const ex & param6)595 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3, const ex & param4, const ex & param5, const ex & param6)
596 : exprseq(param1, param2, param3, param4, param5, param6), serial(ser)
597 {
598 tinfo_key = &function::tinfo_static;
599 }
600
601 // end of generated lines
602
function(unsigned ser,exprseq es)603 function::function(unsigned ser, exprseq es) : exprseq(std::move(es)), serial(ser)
604 {
605 tinfo_key = &function::tinfo_static;
606
607 // Force re-evaluation even if the exprseq was already evaluated
608 // (the exprseq copy constructor copies the flags)
609 clearflag(status_flags::evaluated);
610 }
611
function(unsigned ser,const exvector & v,bool discardable)612 function::function(unsigned ser, const exvector & v, bool discardable)
613 : exprseq(v,discardable), serial(ser)
614 {
615 tinfo_key = &function::tinfo_static;
616 }
617
function(unsigned ser,std::unique_ptr<exvector> vp)618 function::function(unsigned ser, std::unique_ptr<exvector> vp)
619 : exprseq(std::move(vp)), serial(ser)
620 {
621 tinfo_key = &function::tinfo_static;
622 }
623
624 //////////
625 // archiving
626 //////////
627
628 /** Construct object from archive_node. */
function(const archive_node & n,lst & sym_lst)629 function::function(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
630 {
631 // get python_func flag
632 // since python_func used to be a bool flag in the old days,
633 // in order to unarchive old format archives, we first check if the
634 // archive node contains a bool property
635 // if it doesn't we look for the unsigned property indicating which
636 // custom functions are defined in python
637 unsigned python_func;
638 bool old_python_func;
639 if (n.find_bool("python", old_python_func))
640 python_func = old_python_func ? 0xFFFF : 0;
641 else if(!n.find_unsigned("python", python_func))
642 throw std::runtime_error("function::function archive error: cannot read python_func flag");
643 std::string s;
644 if (python_func != 0u) {
645 // read the pickle from the archive
646 if (!n.find_string("pickle", s))
647 throw std::runtime_error("function::function archive error: cannot read pickled function");
648 // unpickle
649 PyObject* arg = Py_BuildValue("s#",s.c_str(), s.size());
650 PyObject* sfunc = py_funcs.py_loads(arg);
651 Py_DECREF(arg);
652 if (PyErr_Occurred() != nullptr) {
653 throw(std::runtime_error("function::function archive error: caught exception in py_loads"));
654 }
655 // get the serial of the new SFunction
656 unsigned int ser = py_funcs.py_get_serial_from_sfunction(sfunc);
657 if (PyErr_Occurred() != nullptr) {
658 throw(std::runtime_error("function::function archive error: cannot get serial from SFunction"));
659 }
660 // set serial
661 serial = ser;
662 } else { // otherwise
663 // Find serial number by function name
664 if (n.find_string("name", s)) {
665 unsigned int ser = 0;
666 unsigned int nargs = seq.size();
667 for (const auto & elem : registered_functions()) {
668 if (s == elem.name && nargs == elem.nparams) {
669 serial = ser;
670 return;
671 }
672 ++ser;
673 }
674 // if the name is not already in the registry, we are
675 // unarchiving a SymbolicFunction without any custom methods
676 // Call Python to create a new symbolic function with name s
677 // and get the serial of this new SymbolicFunction
678 ser = py_funcs.py_get_serial_for_new_sfunction(s, nargs);
679 if (PyErr_Occurred() != nullptr) {
680 throw(std::runtime_error("function::function archive error: cannot create new symbolic function " + s));
681 }
682 serial = ser;
683 //throw (std::runtime_error("unknown function '" + s + "' in archive"));
684 } else
685 throw (std::runtime_error("unnamed function in archive"));
686 }
687 }
688
689 /** Unarchive the object. */
unarchive(const archive_node & n,lst & sym_lst)690 ex function::unarchive(const archive_node &n, lst &sym_lst)
691 {
692 return (new function(n, sym_lst))->setflag(status_flags::dynallocated);
693 }
694
695 /** Archive the object. */
archive(archive_node & n) const696 void function::archive(archive_node &n) const
697 {
698 inherited::archive(n);
699 GINAC_ASSERT(serial < registered_functions().size());
700 // we use Python's pickling mechanism to archive symbolic functions
701 // with customized methods defined in Python. Symbolic functions
702 // defined from c++ or those without custom methods are archived
703 // directly, without calling Python. The python_func flag indicates if
704 // we should use the python unpickling mechanism, or the regular
705 // unarchiving for c++ functions.
706 unsigned python_func = registered_functions()[serial].python_func;
707 if (python_func != 0u) {
708 n.add_unsigned("python", python_func);
709 // find the corresponding SFunction object
710 PyObject* sfunc = py_funcs.py_get_sfunction_from_serial(serial);
711 if (PyErr_Occurred() != nullptr) {
712 throw(std::runtime_error("function::archive cannot get serial from SFunction"));
713 }
714 // call python to pickle it
715 std::string* pickled = py_funcs.py_dumps(sfunc);
716 if (PyErr_Occurred() != nullptr) {
717 throw(std::runtime_error("function::archive py_dumps raised exception"));
718 }
719 // store the pickle in the archive
720 n.add_string("pickle", *pickled);
721 delete pickled;
722 } else {
723 n.add_unsigned("python", 0);
724 n.add_string("name", registered_functions()[serial].name);
725 }
726 }
727
728 //////////
729 // functions overriding virtual functions from base classes
730 //////////
731
732 // public
733
print(const print_context & c,unsigned level) const734 void function::print(const print_context & c, unsigned level) const
735 {
736 GINAC_ASSERT(serial<registered_functions().size());
737 // Dynamically dispatch on print_context type
738 const print_context_class_info *pc_info = &c.get_class_info();
739 if (serial >= static_cast<unsigned>(py_funcs.py_get_ginac_serial())) {
740 //convert arguments to a PyTuple of Expressions
741 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
742
743 std::string* sout;
744 if (is_a<print_latex>(c)) {
745 sout = py_funcs.py_latex_function(serial, args);
746 if (PyErr_Occurred() != nullptr) {
747 throw(std::runtime_error("function::print(): python print function raised exception"));
748 }
749 c.s << *sout;
750 c.s.flush();
751 }
752 else if (is_a<print_tree>(c)) {
753 sout = py_funcs.py_print_function(serial, args);
754 if (PyErr_Occurred() != nullptr) {
755 throw(std::runtime_error("function::print(): python print function raised exception"));
756 }
757 std::string fname = sout->substr(0, sout->find_first_of('('));
758 c.s << std::string(level, ' ') << class_name() << " "
759 << fname << " @" << this << ", serial=" << serial
760 << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
761 << ", nops=" << nops()
762 << std::endl;
763 unsigned delta_indent = dynamic_cast<const print_tree &>(c).delta_indent;
764 for (const auto& term : seq)
765 term.print(c, level + delta_indent);
766 c.s << std::string(level + delta_indent, ' ') << "=====" << std::endl;
767 }
768 else {
769 sout = py_funcs.py_print_function(serial, args);
770 if (PyErr_Occurred() != nullptr) {
771 throw(std::runtime_error("function::print(): python print function raised exception"));
772 }
773 c.s << *sout;
774 c.s.flush();
775 }
776
777 delete sout;
778 Py_DECREF(args);
779 } else {
780
781 if (is_a<print_latex>(c)) {
782 PyObject* sfunc = py_funcs.py_get_sfunction_from_serial(serial);
783 if (PyObject_HasAttrString(sfunc, "_print_latex_")) {
784 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
785 std::string* sout;
786 sout = py_funcs.py_latex_function(serial, args);
787 if (PyErr_Occurred() != nullptr) {
788 throw(std::runtime_error("function::print(): python print function raised exception"));
789 }
790 c.s << *sout;
791 c.s.flush();
792 delete sout;
793 Py_DECREF(args);
794 return;
795 }
796 }
797
798 const function_options &opt = registered_functions()[serial];
799 const std::vector<print_funcp> &pdt = opt.print_dispatch_table;
800
801
802 next_context:
803 unsigned id = pc_info->options.get_id();
804 if (id >= pdt.size() || pdt[id] == nullptr) {
805
806 // Method not found, try parent print_context class
807 const print_context_class_info *parent_pc_info = pc_info->get_parent();
808 if (parent_pc_info != nullptr) {
809 pc_info = parent_pc_info;
810 goto next_context;
811 }
812
813 // Method still not found, use default output
814 if (is_a<print_tree>(c)) {
815
816 c.s << std::string(level, ' ') << class_name() << " "
817 << opt.name << " @" << this << ", serial=" << serial
818 << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
819 << ", nops=" << nops()
820 << std::endl;
821 unsigned delta_indent = dynamic_cast<const print_tree &>(c).delta_indent;
822 for (auto & elem : seq)
823 elem.print(c, level + delta_indent);
824 c.s << std::string(level + delta_indent, ' ') << "=====" << std::endl;
825
826 } else if (is_a<print_latex>(c)) {
827 c.s << opt.TeX_name;
828 printseq(c, "\\left(", ',', "\\right)", exprseq::precedence(), function::precedence());
829 } else {
830 c.s << opt.name;
831 printseq(c, "(", ',', ")", exprseq::precedence(), function::precedence());
832 }
833
834 } else {
835
836 // Method found, call it
837 current_serial = serial;
838 if (opt.print_use_exvector_args)
839 (reinterpret_cast<print_funcp_exvector>(pdt[id]))(seq, c);
840 else
841 switch (opt.nparams) {
842 // the following lines have been generated for max. 14 parameters
843 case 1:
844 (reinterpret_cast<print_funcp_1>(pdt[id]))(seq[1 - 1], c);
845 break;
846 case 2:
847 (reinterpret_cast<print_funcp_2>(pdt[id]))(seq[1 - 1], seq[2 - 1], c);
848 break;
849 case 3:
850 (reinterpret_cast<print_funcp_3>(pdt[id]))(seq[1 - 1], seq[2 - 1], seq[3 - 1], c);
851 break;
852
853 // end of generated lines
854 default:
855 throw(std::logic_error("function::print(): invalid nparams"));
856 }
857 }
858 }
859 }
860
expand(unsigned options) const861 ex function::expand(unsigned options) const
862 {
863 return inherited::expand(options);
864 }
865
eval(int level) const866 ex function::eval(int level) const
867 {
868 if (level>1) {
869 // first evaluate children, then we will end up here again
870 return function(serial,evalchildren(level));
871 }
872
873 GINAC_ASSERT(serial<registered_functions().size());
874 const function_options &opt = registered_functions()[serial];
875
876 bool use_remember = opt.use_remember;
877 ex eval_result;
878 if (use_remember && lookup_remember_table(eval_result)) {
879 return eval_result;
880 }
881 current_serial = serial;
882
883 if (opt.eval_f==nullptr)
884 return this->hold();
885 eval_funcp eval_f;
886 if (opt.pynac_eval_f == nullptr)
887 eval_f = opt.eval_f;
888 else
889 eval_f = opt.pynac_eval_f;
890
891 if (opt.pynac_eval_f == nullptr
892 and (opt.python_func & function_options::eval_python_f) != 0u) {
893 // convert seq to a PyTuple of Expressions
894 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
895 // call opt.eval_f with this list
896 PyObject* pyresult = PyObject_CallMethod(reinterpret_cast<PyObject*>(eval_f),
897 const_cast<char*>("_eval_"), const_cast<char*>("O"), args);
898 Py_DECREF(args);
899 if (pyresult == nullptr) {
900 throw(std::runtime_error("function::eval(): python function raised exception"));
901 }
902 if ( pyresult == Py_None ) {
903 return this->hold();
904 }
905 // convert output Expression to an ex
906 eval_result = py_funcs.pyExpression_to_ex(pyresult);
907 Py_DECREF(pyresult);
908 if (PyErr_Occurred() != nullptr) {
909 throw(std::runtime_error("function::eval(): python function (Expression_to_ex) raised exception"));
910 }
911 }
912 else if (opt.eval_use_exvector_args)
913 eval_result = (reinterpret_cast<eval_funcp_exvector>(eval_f))(seq);
914 else
915 switch (opt.nparams) {
916 // the following lines have been generated for max. 14 parameters
917 case 1:
918 eval_result = (reinterpret_cast<eval_funcp_1>(eval_f))(seq[1-1]);
919 break;
920 case 2:
921 eval_result = (reinterpret_cast<eval_funcp_2>(eval_f))(seq[1-1], seq[2-1]);
922 break;
923 case 3:
924 eval_result = (reinterpret_cast<eval_funcp_3>(eval_f))(seq[1-1], seq[2-1], seq[3-1]);
925 break;
926 case 6:
927 eval_result = (reinterpret_cast<eval_funcp_6>(eval_f))(seq[1-1], seq[2-1], seq[3-1], seq[4-1], seq[5-1], seq[6-1]);
928 break;
929
930 // end of generated lines
931 default:
932 throw(std::logic_error("function::eval(): invalid nparams"));
933 }
934 if (use_remember) {
935 store_remember_table(eval_result);
936 }
937 return eval_result;
938 }
939
evalf(int level,PyObject * kwds) const940 ex function::evalf(int level, PyObject* kwds) const
941 {
942 GINAC_ASSERT(serial<registered_functions().size());
943 const function_options &opt = registered_functions()[serial];
944
945 // Evaluate children first?
946 exvector eseq;
947 if (level == 1 || !(opt.evalf_params_first))
948 eseq = seq;
949 else if (level == -max_recursion_level)
950 throw(std::runtime_error("max recursion level reached"));
951 else {
952 eseq.reserve(seq.size());
953 --level;
954 for (const auto & elem : seq)
955 eseq.push_back(elem.evalf(level, kwds));
956 }
957
958 if (opt.evalf_f == nullptr) {
959 if (opt.nparams == 1 and is_exactly_a<numeric>(eseq[1-1])) {
960 const numeric& n = ex_to<numeric>(eseq[1-1]);
961 try {
962 return n.try_py_method(get_name());
963 }
964 catch (std::logic_error) {
965 try {
966 const numeric& nn = ex_to<numeric>(n.evalf()).try_py_method(get_name());
967 return nn.to_dict_parent(kwds);
968 }
969 catch (std::logic_error) {}
970 }
971 }
972 return function(serial,eseq).hold();
973 }
974 current_serial = serial;
975 if ((opt.python_func & function_options::evalf_python_f) != 0u) {
976 // convert seq to a PyTuple of Expressions
977 PyObject* args = py_funcs.exvector_to_PyTuple(eseq);
978 // call opt.evalf_f with this list
979 PyObject* pyresult = PyEval_CallObjectWithKeywords(
980 PyObject_GetAttrString(reinterpret_cast<PyObject*>(opt.evalf_f),
981 "_evalf_"), args, kwds);
982 Py_DECREF(args);
983 if (pyresult == nullptr) {
984 throw(std::runtime_error("function::evalf(): python function raised exception"));
985 }
986 // convert output Expression to an ex
987 ex result = py_funcs.pyExpression_to_ex(pyresult);
988 Py_DECREF(pyresult);
989 if (PyErr_Occurred() != nullptr) {
990 throw(std::runtime_error("function::evalf(): python function (pyExpression_to_ex) raised exception"));
991 }
992 return result;
993 }
994 if (opt.evalf_use_exvector_args)
995 return (reinterpret_cast<evalf_funcp_exvector>(opt.evalf_f))(seq, kwds);
996 switch (opt.nparams) {
997 // the following lines have been generated for max. 14 parameters
998 case 1:
999 return (reinterpret_cast<evalf_funcp_1>(opt.evalf_f))(eseq[1-1], kwds);
1000 case 2:
1001 return (reinterpret_cast<evalf_funcp_2>(opt.evalf_f))(eseq[1-1], eseq[2-1], kwds);
1002 case 3:
1003 return (reinterpret_cast<evalf_funcp_3>(opt.evalf_f))(eseq[1-1], eseq[2-1], eseq[3-1], kwds);
1004 case 6:
1005 return (reinterpret_cast<evalf_funcp_6>(opt.evalf_f))(eseq[1-1], eseq[2-1], eseq[3-1], eseq[4-1], eseq[5-1], eseq[6-1], kwds);
1006
1007 // end of generated lines
1008 }
1009 throw(std::logic_error("function::evalf(): invalid nparams"));
1010 }
1011
calchash() const1012 long function::calchash() const
1013 {
1014 long v = golden_ratio_hash(golden_ratio_hash((intptr_t)tinfo()) ^ serial);
1015 for (size_t i=0; i<nops(); i++) {
1016 v = rotate_left(v);
1017 v ^= this->op(i).gethash();
1018 }
1019
1020 if (is_evaluated()) {
1021 setflag(status_flags::hash_calculated);
1022 hashvalue = v;
1023 }
1024 return v;
1025 }
1026
thiscontainer(const exvector & v) const1027 ex function::thiscontainer(const exvector & v) const
1028 {
1029 return function(serial, v);
1030 }
1031
thiscontainer(std::unique_ptr<exvector> vp) const1032 ex function::thiscontainer(std::unique_ptr<exvector> vp) const
1033 {
1034 return function(serial, std::move(vp));
1035 }
1036
1037 /** Implementation of ex::series for functions.
1038 * @see ex::series */
series(const relational & r,int order,unsigned options) const1039 ex function::series(const relational & r, int order, unsigned options) const
1040 {
1041 GINAC_ASSERT(serial<registered_functions().size());
1042 const function_options &opt = registered_functions()[serial];
1043
1044 if (opt.series_f==nullptr) {
1045 return basic::series(r, order);
1046 }
1047 ex res;
1048 current_serial = serial;
1049 if ((opt.python_func & function_options::series_python_f) != 0u) {
1050 // convert seq to a PyTuple of Expressions
1051 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1052 // create a dictionary {'order': order, 'options':options}
1053 PyObject* kwds = Py_BuildValue("{s:i,s:I}","order",order,"options",options);
1054 // add variable to expand for as a keyword argument
1055 PyDict_SetItemString(kwds, "var", py_funcs.ex_to_pyExpression(r.lhs()));
1056 // add the point of expansion as a keyword argument
1057 PyDict_SetItemString(kwds, "at", py_funcs.ex_to_pyExpression(r.rhs()));
1058 // call opt.series_f with this list
1059 PyObject* pyresult = PyEval_CallObjectWithKeywords(
1060 PyObject_GetAttrString(reinterpret_cast<PyObject*>(opt.series_f),
1061 "_series_"), args, kwds);
1062 Py_DECREF(args);
1063 Py_DECREF(kwds);
1064 if (pyresult == nullptr) {
1065 throw(std::runtime_error("function::series(): python function raised exception"));
1066 }
1067 // convert output Expression to an ex
1068 ex result = py_funcs.pyExpression_to_ex(pyresult);
1069 Py_DECREF(pyresult);
1070 if (PyErr_Occurred() != nullptr) {
1071 throw(std::runtime_error("function::series(): python function (pyExpression_to_ex) raised exception"));
1072 }
1073 return result;
1074 }
1075 if (opt.series_use_exvector_args) {
1076 try {
1077 res = (reinterpret_cast<series_funcp_exvector>(opt.series_f))(seq, r, order, options);
1078 } catch (do_taylor) {
1079 res = basic::series(r, order, options);
1080 }
1081 return res;
1082 }
1083 switch (opt.nparams) {
1084 // the following lines have been generated for max. 14 parameters
1085 case 1:
1086 try {
1087 res = (reinterpret_cast<series_funcp_1>(opt.series_f))(seq[1-1],r,order,options);
1088 } catch (do_taylor) {
1089 res = basic::series(r, order, options);
1090 }
1091 return res;
1092 case 2:
1093 try {
1094 res = (reinterpret_cast<series_funcp_2>(opt.series_f))(seq[1-1], seq[2-1],r,order,options);
1095 } catch (do_taylor) {
1096 res = basic::series(r, order, options);
1097 }
1098 return res;
1099 case 3:
1100 try {
1101 res = (reinterpret_cast<series_funcp_3>(opt.series_f))(seq[1-1], seq[2-1], seq[3-1],r,order,options);
1102 } catch (do_taylor) {
1103 res = basic::series(r, order, options);
1104 }
1105 return res;
1106
1107 // end of generated lines
1108 }
1109 throw(std::logic_error("function::series(): invalid nparams"));
1110 }
1111
1112
1113 /** Implementation of ex::subs for functions. */
subs(const exmap & m,unsigned options) const1114 ex function::subs(const exmap & m, unsigned options) const
1115 {
1116 GINAC_ASSERT(serial<registered_functions().size());
1117 const function_options & opt = registered_functions()[serial];
1118
1119 if ((opt.python_func & function_options::subs_python_f) != 0u) {
1120 // convert seq to a PyTuple of Expressions
1121 PyObject* args = py_funcs.subs_args_to_PyTuple(m, options, seq);
1122 // call opt.subs_f with this list
1123 PyObject* pyresult = PyObject_CallMethod(
1124 reinterpret_cast<PyObject*>(opt.subs_f),
1125 const_cast<char*>("_subs_"), const_cast<char*>("O"), args);
1126 Py_DECREF(args);
1127 if (pyresult == nullptr) {
1128 throw(std::runtime_error("function::subs(): python method (_subs_) raised exception"));
1129 }
1130 // convert output Expression to an ex
1131 ex result = py_funcs.pyExpression_to_ex(pyresult);
1132 Py_DECREF(pyresult);
1133 if (PyErr_Occurred() != nullptr) {
1134 throw(std::runtime_error("function::subs(): python function (pyExpression_to_ex) raised exception"));
1135 }
1136 return result;
1137 }
1138 if (opt.subs_f==nullptr)
1139 return exprseq::subs(m, options);
1140
1141 switch (opt.nparams) {
1142 case 1:
1143 return (reinterpret_cast<subs_funcp_1>(opt.subs_f))(m, seq[1-1]);
1144 case 2:
1145 return (reinterpret_cast<subs_funcp_2>(opt.subs_f))(m, seq[1-1], seq[2-1]);
1146 case 3:
1147 return (reinterpret_cast<subs_funcp_3>(opt.subs_f))(m, seq[1-1], seq[2-1], seq[3-1]);
1148
1149 // end of generated lines
1150 }
1151 throw(std::logic_error("function::subs(): invalid nparams"));
1152 }
1153
1154 /** Implementation of ex::conjugate for functions. */
conjugate() const1155 ex function::conjugate() const
1156 {
1157 GINAC_ASSERT(serial<registered_functions().size());
1158 const function_options & opt = registered_functions()[serial];
1159
1160 if (opt.conjugate_f==nullptr) {
1161 return conjugate_function(*this).hold();
1162 }
1163
1164 if ((opt.python_func & function_options::conjugate_python_f) != 0u) {
1165 // convert seq to a PyTuple of Expressions
1166 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1167 // call opt.conjugate_f with this list
1168 PyObject* pyresult = PyObject_CallMethod(
1169 reinterpret_cast<PyObject*>(opt.conjugate_f),
1170 const_cast<char*>("_conjugate_"), const_cast<char*>("O"), args);
1171 Py_DECREF(args);
1172 if (pyresult == nullptr) {
1173 throw(std::runtime_error("function::conjugate(): python function raised exception"));
1174 }
1175 // convert output Expression to an ex
1176 ex result = py_funcs.pyExpression_to_ex(pyresult);
1177 Py_DECREF(pyresult);
1178 if (PyErr_Occurred() != nullptr) {
1179 throw(std::runtime_error("function::conjugate(): python function (pyExpression_to_ex) raised exception"));
1180 }
1181 return result;
1182 }
1183 if (opt.conjugate_use_exvector_args) {
1184 return (reinterpret_cast<conjugate_funcp_exvector>(opt.conjugate_f))(seq);
1185 }
1186
1187 switch (opt.nparams) {
1188 // the following lines have been generated for max. 14 parameters
1189 case 1:
1190 return (reinterpret_cast<conjugate_funcp_1>(opt.conjugate_f))(seq[1-1]);
1191 case 2:
1192 return (reinterpret_cast<conjugate_funcp_2>(opt.conjugate_f))(seq[1-1], seq[2-1]);
1193 case 3:
1194 return (reinterpret_cast<conjugate_funcp_3>(opt.conjugate_f))(seq[1-1], seq[2-1], seq[3-1]);
1195
1196 // end of generated lines
1197 }
1198 throw(std::logic_error("function::conjugate(): invalid nparams"));
1199 }
1200
1201 /** Implementation of ex::real_part for functions. */
real_part() const1202 ex function::real_part() const
1203 {
1204 GINAC_ASSERT(serial<registered_functions().size());
1205 const function_options & opt = registered_functions()[serial];
1206
1207 if (opt.real_part_f==nullptr)
1208 return basic::real_part();
1209
1210 if ((opt.python_func & function_options::real_part_python_f) != 0u) {
1211 // convert seq to a PyTuple of Expressions
1212 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1213 // call opt.real_part_f with this list
1214 PyObject* pyresult = PyObject_CallMethod(reinterpret_cast<PyObject*>(opt.real_part_f),
1215 const_cast<char*>("_real_part_"), const_cast<char*>("O"), args);
1216 Py_DECREF(args);
1217 if (pyresult == nullptr) {
1218 throw(std::runtime_error("function::real_part(): python function raised exception"));
1219 }
1220 // convert output Expression to an ex
1221 ex result = py_funcs.pyExpression_to_ex(pyresult);
1222 Py_DECREF(pyresult);
1223 if (PyErr_Occurred() != nullptr) {
1224 throw(std::runtime_error("function::real_part(): python function (pyExpression_to_ex) raised exception"));
1225 }
1226 return result;
1227 }
1228 if (opt.real_part_use_exvector_args)
1229 return (reinterpret_cast<real_part_funcp_exvector>(opt.real_part_f))(seq);
1230
1231 switch (opt.nparams) {
1232 // the following lines have been generated for max. 14 parameters
1233 case 1:
1234 return (reinterpret_cast<real_part_funcp_1>(opt.real_part_f))(seq[1-1]);
1235 case 2:
1236 return (reinterpret_cast<real_part_funcp_2>(opt.real_part_f))(seq[1-1], seq[2-1]);
1237 case 3:
1238 return (reinterpret_cast<real_part_funcp_3>(opt.real_part_f))(seq[1-1], seq[2-1], seq[3-1]);
1239
1240 // end of generated lines
1241 }
1242 throw(std::logic_error("function::real_part(): invalid nparams"));
1243 }
1244
1245 /** Implementation of ex::imag_part for functions. */
imag_part() const1246 ex function::imag_part() const
1247 {
1248 GINAC_ASSERT(serial<registered_functions().size());
1249 const function_options & opt = registered_functions()[serial];
1250
1251 if (opt.imag_part_f==nullptr)
1252 return basic::imag_part();
1253
1254 if ((opt.python_func & function_options::imag_part_python_f) != 0u) {
1255 // convert seq to a PyTuple of Expressions
1256 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1257 // call opt.imag_part_f with this list
1258 PyObject* pyresult = PyObject_CallMethod(reinterpret_cast<PyObject*>(opt.imag_part_f),
1259 const_cast<char*>("_imag_part_"), const_cast<char*>("O"), args);
1260 Py_DECREF(args);
1261 if (pyresult == nullptr) {
1262 throw(std::runtime_error("function::imag_part(): python function raised exception"));
1263 }
1264 // convert output Expression to an ex
1265 ex result = py_funcs.pyExpression_to_ex(pyresult);
1266 Py_DECREF(pyresult);
1267 if (PyErr_Occurred() != nullptr) {
1268 throw(std::runtime_error("function::imag_part(): python function (pyExpression_to_ex) raised exception"));
1269 }
1270 return result;
1271 }
1272 if (opt.imag_part_use_exvector_args)
1273 return (reinterpret_cast<imag_part_funcp_exvector>(opt.imag_part_f))(seq);
1274
1275 switch (opt.nparams) {
1276 // the following lines have been generated for max. 14 parameters
1277 case 1:
1278 return (reinterpret_cast<imag_part_funcp_1>(opt.imag_part_f))(seq[1-1]);
1279 case 2:
1280 return (reinterpret_cast<imag_part_funcp_2>(opt.imag_part_f))(seq[1-1], seq[2-1]);
1281 case 3:
1282 return (reinterpret_cast<imag_part_funcp_3>(opt.imag_part_f))(seq[1-1], seq[2-1], seq[3-1]);
1283
1284 // end of generated lines
1285 }
1286 throw(std::logic_error("function::imag_part(): invalid nparams"));
1287 }
1288
1289 // protected
1290
1291 /** Implementation of ex::diff() for functions. It applies the chain rule,
1292 * except for the Order term function.
1293 * @see ex::diff */
derivative(const symbol & s) const1294 ex function::derivative(const symbol & s) const
1295 {
1296 ex result;
1297
1298 /*
1299 if (serial == Order_SERIAL::serial) {
1300 // Order Term function only differentiates the argument
1301 return Order(seq[0].diff(s));
1302 */
1303 GINAC_ASSERT(serial<registered_functions().size());
1304 const function_options &opt = registered_functions()[serial];
1305
1306 try {
1307 // Explicit derivation
1308 return expl_derivative(s);
1309 } catch (...) {}
1310
1311 // Check if we need to apply chain rule
1312 if (!opt.apply_chain_rule) {
1313 if (opt.derivative_f == nullptr)
1314 throw(std::runtime_error("function::derivative(): custom derivative function must be defined"));
1315
1316 if ((opt.python_func & function_options::derivative_python_f) != 0u) {
1317 // convert seq to a PyTuple of Expressions
1318 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1319 // create a dictionary {'diff_param': s}
1320 PyObject* symb = py_funcs.ex_to_pyExpression(s);
1321 PyObject* kwds = Py_BuildValue("{s:O}","diff_param",
1322 symb);
1323 // call opt.derivative_f with this list
1324 PyObject* pyresult = PyEval_CallObjectWithKeywords(
1325 PyObject_GetAttrString(
1326 reinterpret_cast<PyObject*>(opt.derivative_f),
1327 "_tderivative_"), args, kwds);
1328 Py_DECREF(symb);
1329 Py_DECREF(args);
1330 Py_DECREF(kwds);
1331 if (pyresult == nullptr) {
1332 throw(std::runtime_error("function::derivative(): python function raised exception"));
1333 }
1334 // convert output Expression to an ex
1335 result = py_funcs.pyExpression_to_ex(pyresult);
1336 Py_DECREF(pyresult);
1337 if (PyErr_Occurred() != nullptr) {
1338 throw(std::runtime_error("function::derivative(): python function (pyExpression_to_ex) raised exception"));
1339 }
1340 return result;
1341 }
1342 // C++ function
1343 if (!opt.derivative_use_exvector_args)
1344 throw(std::runtime_error("function::derivative(): cannot call C++ function without exvector args"));
1345
1346 return (reinterpret_cast<derivative_funcp_exvector_symbol>(opt.derivative_f))(seq, s);
1347
1348 }
1349 // Chain rule
1350 ex arg_diff;
1351 size_t num = seq.size();
1352 for (size_t i=0; i<num; i++) {
1353 arg_diff = seq[i].diff(s);
1354 // We apply the chain rule only when it makes sense. This is not
1355 // just for performance reasons but also to allow functions to
1356 // throw when differentiated with respect to one of its arguments
1357 // without running into trouble with our automatic full
1358 // differentiation:
1359 if (!arg_diff.is_zero())
1360 result += pderivative(i)*arg_diff;
1361 }
1362
1363 return result;
1364 }
1365
compare_same_type(const basic & other) const1366 int function::compare_same_type(const basic & other) const
1367 {
1368 GINAC_ASSERT(is_a<function>(other));
1369 const function & o = static_cast<const function &>(other);
1370
1371 if (serial != o.serial)
1372 return serial < o.serial ? -1 : 1;
1373
1374 return exprseq::compare_same_type(o);
1375 }
1376
1377
is_equal_same_type(const basic & other) const1378 bool function::is_equal_same_type(const basic & other) const
1379 {
1380 GINAC_ASSERT(is_a<function>(other));
1381 const function & o = static_cast<const function &>(other);
1382
1383 if (serial != o.serial)
1384 return false;
1385
1386 return exprseq::is_equal_same_type(o);
1387 }
1388
match_same_type(const basic & other) const1389 bool function::match_same_type(const basic & other) const
1390 {
1391 GINAC_ASSERT(is_a<function>(other));
1392 const function & o = static_cast<const function &>(other);
1393
1394 return serial == o.serial;
1395 }
1396
match(const ex & pattern,exmap & map) const1397 bool function::match(const ex & pattern, exmap& map) const
1398 {
1399 if (is_exactly_a<wildcard>(pattern)) {
1400 const auto& it = map.find(pattern);
1401 if (it != map.end())
1402 return is_equal(ex_to<basic>(it->second));
1403 map[pattern] = *this;
1404 return true;
1405 }
1406 if (not is_exactly_a<function>(pattern))
1407 return false;
1408 CMatcher cm(*this, pattern, map);
1409 const opt_exmap& m = cm.get();
1410 if (not m)
1411 return false;
1412 map = m.value();
1413 return true;
1414 }
1415
1416
return_type() const1417 unsigned function::return_type() const
1418 {
1419 GINAC_ASSERT(serial<registered_functions().size());
1420 const function_options &opt = registered_functions()[serial];
1421
1422 if (opt.use_return_type) {
1423 // Return type was explicitly specified
1424 return opt.return_type;
1425 }
1426 // Default behavior is to use the return type of the first
1427 // argument. Thus, exp() of a matrix behaves like a matrix, etc.
1428 if (seq.empty())
1429 return return_types::commutative;
1430
1431 return seq.begin()->return_type();
1432
1433 }
1434
return_type_tinfo() const1435 tinfo_t function::return_type_tinfo() const
1436 {
1437 GINAC_ASSERT(serial<registered_functions().size());
1438 const function_options &opt = registered_functions()[serial];
1439
1440 if (opt.use_return_type) {
1441 // Return type was explicitly specified
1442 return opt.return_type_tinfo;
1443 }
1444 // Default behavior is to use the return type of the first
1445 // argument. Thus, exp() of a matrix behaves like a matrix, etc.
1446 if (seq.empty())
1447 return this;
1448
1449 return seq.begin()->return_type_tinfo();
1450
1451 }
1452
1453 //////////
1454 // new virtual functions which can be overridden by derived classes
1455 //////////
1456
1457 // none
1458
1459 //////////
1460 // non-virtual functions in this class
1461 //////////
1462
1463 // protected
1464
pderivative(unsigned diff_param) const1465 ex function::pderivative(unsigned diff_param) const // partial differentiation
1466 {
1467 GINAC_ASSERT(serial<registered_functions().size());
1468 const function_options &opt = registered_functions()[serial];
1469
1470 // No derivative defined? Then return abstract derivative object
1471 if (opt.derivative_f == nullptr)
1472 return fderivative(serial, diff_param, seq);
1473
1474 current_serial = serial;
1475 if ((opt.python_func & function_options::derivative_python_f) != 0u) {
1476 // convert seq to a PyTuple of Expressions
1477 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1478 // create a dictionary {'diff_param': diff_param}
1479 PyObject* kwds = Py_BuildValue("{s:I}","diff_param",diff_param);
1480 // call opt.derivative_f with this list
1481 PyObject* pyresult = PyEval_CallObjectWithKeywords(
1482 PyObject_GetAttrString(reinterpret_cast<PyObject*>(opt.derivative_f),
1483 "_derivative_"), args, kwds);
1484 Py_DECREF(args);
1485 Py_DECREF(kwds);
1486 if (pyresult == nullptr) {
1487 throw(std::runtime_error("function::pderivative(): python function raised exception"));
1488 }
1489 if ( pyresult == Py_None ) {
1490 return fderivative(serial, diff_param, seq);
1491 }
1492 // convert output Expression to an ex
1493 ex result = py_funcs.pyExpression_to_ex(pyresult);
1494 Py_DECREF(pyresult);
1495 if (PyErr_Occurred() != nullptr) {
1496 throw(std::runtime_error("function::pderivative(): python function (pyExpression_to_ex) raised exception"));
1497 }
1498 return result;
1499 }
1500 if (opt.derivative_use_exvector_args)
1501 return (reinterpret_cast<derivative_funcp_exvector>(opt.derivative_f))(seq, diff_param);
1502 switch (opt.nparams) {
1503 // the following lines have been generated for max. 14 parameters
1504 case 1:
1505 return (reinterpret_cast<derivative_funcp_1>(opt.derivative_f))(seq[1-1],diff_param);
1506 case 2:
1507 return (reinterpret_cast<derivative_funcp_2>(opt.derivative_f))(seq[1-1], seq[2-1],diff_param);
1508 case 3:
1509 return (reinterpret_cast<derivative_funcp_3>(opt.derivative_f))(seq[1-1], seq[2-1], seq[3-1],diff_param);
1510 case 6:
1511 return (reinterpret_cast<derivative_funcp_6>(opt.derivative_f))(seq[1-1], seq[2-1], seq[3-1], seq[4-1], seq[5-1], seq[6-1], diff_param);
1512
1513 // end of generated lines
1514 }
1515 throw(std::logic_error("function::pderivative(): no diff function defined"));
1516 }
1517
expl_derivative(const symbol & s) const1518 ex function::expl_derivative(const symbol & s) const // explicit differentiation
1519 {
1520 GINAC_ASSERT(serial<registered_functions().size());
1521 const function_options &opt = registered_functions()[serial];
1522
1523 if (opt.expl_derivative_f) {
1524 // Invoke the defined explicit derivative function.
1525 current_serial = serial;
1526 if (opt.expl_derivative_use_exvector_args)
1527 return (reinterpret_cast<expl_derivative_funcp_exvector>(opt.expl_derivative_f))(seq, s);
1528 switch (opt.nparams) {
1529 // the following lines have been generated for max. 14 parameters
1530 case 1:
1531 return (reinterpret_cast<expl_derivative_funcp_1>(opt.expl_derivative_f))(seq[0], s);
1532 case 2:
1533 return (reinterpret_cast<expl_derivative_funcp_2>(opt.expl_derivative_f))(seq[0], seq[1], s);
1534 case 3:
1535 return (reinterpret_cast<expl_derivative_funcp_3>(opt.expl_derivative_f))(seq[0], seq[1], seq[2], s);
1536 }
1537 }
1538 // There is no fallback for explicit derivative.
1539 throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
1540 }
1541
/** Raise this function to a power.
 *
 *  Without a registered power function a GiNaC::power object flagged as
 *  evaluated is returned. Otherwise the Python `_power_` method (called
 *  with the arguments positionally and the keyword `power_param`) or the
 *  registered C++ function (exvector or 1-3 parameters) is called.
 *
 *  @param power_param  the exponent
 *  @throws std::runtime_error if the Python side raised an exception
 *  @throws std::logic_error if no C++ overload matches the parameter count */
ex function::power(const ex & power_param) const // power of function
{
	GINAC_ASSERT(serial<registered_functions().size());
	const function_options &opt = registered_functions()[serial];

	// No power function defined? Then return a plain power object
	if (opt.power_f == nullptr)
		return (new GiNaC::power(*this, power_param))->setflag(status_flags::dynallocated |
		                                                       status_flags::evaluated);

	current_serial = serial;
	if ((opt.python_func & function_options::power_python_f) != 0u) {
		// convert seq to a PyTuple of Expressions
		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
		// create a dictionary {'power_param': power_param}
		PyObject* kwds = PyDict_New();
		// PyDict_SetItemString() does not take over our reference, so the
		// temporary is released once it is stored in the dict.
		PyObject* py_exponent = py_funcs.ex_to_pyExpression(power_param);
		PyDict_SetItemString(kwds, "power_param", py_exponent);
		Py_DECREF(py_exponent);
		// PyObject_GetAttrString() returns a new reference as well.
		PyObject* method = PyObject_GetAttrString(
			reinterpret_cast<PyObject*>(opt.power_f), "_power_");
		if (method == nullptr) {
			Py_DECREF(args);
			Py_DECREF(kwds);
			throw(std::runtime_error("function::power(): python function raised exception"));
		}
		// call opt.power_f with this list
		PyObject* pyresult = PyEval_CallObjectWithKeywords(method, args, kwds);
		Py_DECREF(method);
		Py_DECREF(args);
		Py_DECREF(kwds);
		if (pyresult == nullptr) {
			throw(std::runtime_error("function::power(): python function raised exception"));
		}
		// convert output Expression to an ex
		ex result = py_funcs.pyExpression_to_ex(pyresult);
		Py_DECREF(pyresult);
		if (PyErr_Occurred() != nullptr) {
			throw(std::runtime_error("function::power(): python function (pyExpression_to_ex) raised exception"));
		}
		return result;
	}
	if (opt.power_use_exvector_args)
		return (reinterpret_cast<power_funcp_exvector>(opt.power_f))(seq, power_param);
	switch (opt.nparams) {
		// the following lines have been generated for max. 14 parameters
		case 1:
			return (reinterpret_cast<power_funcp_1>(opt.power_f))(seq[1-1],power_param);
		case 2:
			return (reinterpret_cast<power_funcp_2>(opt.power_f))(seq[1-1], seq[2-1],power_param);
		case 3:
			return (reinterpret_cast<power_funcp_3>(opt.power_f))(seq[1-1], seq[2-1], seq[3-1],power_param);

		// end of generated lines
	}
	throw(std::logic_error("function::power(): no power function defined"));
}
1591
registered_functions()1592 std::vector<function_options> & function::registered_functions()
1593 {
1594 static auto rf = new std::vector<function_options>;
1595 return *rf;
1596 }
1597
lookup_remember_table(ex & result) const1598 bool function::lookup_remember_table(ex & result) const
1599 {
1600 return remember_table::remember_tables()[this->serial].lookup_entry(*this,result);
1601 }
1602
store_remember_table(ex const & result) const1603 void function::store_remember_table(ex const & result) const
1604 {
1605 remember_table::remember_tables()[this->serial].add_entry(*this,result);
1606 }
1607
1608 // public
1609
register_new(function_options const & opt)1610 unsigned function::register_new(function_options const & opt)
1611 {
1612 size_t same_name = 0;
1613 for (auto & elem : registered_functions()) {
1614 if (elem.name==opt.name) {
1615 ++same_name;
1616 }
1617 }
1618 if (same_name>=opt.functions_with_same_name) {
1619 // we do not throw an exception here because this code is
1620 // usually executed before main(), so the exception could not
1621 // caught anyhow
1622 //
1623 // SAGE note:
1624 // We suppress this warning since we allow a user to create
1625 // functions with same name, but different parameters
1626 // Sage SFunction class checks existence of a function before
1627 // allocating a new one.
1628 //
1629 //std::cerr << "WARNING: function name " << opt.name
1630 // << " already in use!" << std::endl;
1631 }
1632 registered_functions().push_back(opt);
1633 if (opt.use_remember) {
1634 remember_table::remember_tables().
1635 emplace_back(opt.remember_size,
1636 opt.remember_assoc_size,
1637 opt.remember_strategy);
1638 } else {
1639 remember_table::remember_tables().emplace_back();
1640 }
1641 return registered_functions().size()-1;
1642 }
1643
1644 /** Find serial number of function by name and number of parameters.
1645 * Throws exception if function was not found. */
find_function(const std::string & name,unsigned nparams)1646 unsigned function::find_function(const std::string &name, unsigned nparams)
1647 {
1648 unsigned serial = 0;
1649 for (const auto & elem : registered_functions()) {
1650 if (elem.get_name() == name && elem.get_nparams() == nparams)
1651 return serial;
1652 ++serial;
1653 }
1654 throw (std::runtime_error("no function '" + name + "' with " + ToString(nparams) + " parameters defined"));
1655 }
1656
1657 /** Return the print name of the function. */
get_name() const1658 std::string function::get_name() const
1659 {
1660 GINAC_ASSERT(serial<registered_functions().size());
1661 return registered_functions()[serial].name;
1662 }
1663
set_domain(unsigned d)1664 void function::set_domain(unsigned d)
1665 {
1666 domain = d;
1667 iflags.clear();
1668 switch (d) {
1669 case domain::complex:
1670 break;
1671 case domain::real:
1672 iflags.set(info_flags::real, true);
1673 break;
1674 case domain::positive:
1675 iflags.set(info_flags::real, true);
1676 iflags.set(info_flags::positive, true);
1677 break;
1678 case domain::integer:
1679 iflags.set(info_flags::real, true);
1680 iflags.set(info_flags::integer, true);
1681 break;
1682 }
1683 }
1684
has_function(const ex & x)1685 bool has_function(const ex & x)
1686 {
1687 if (is_exactly_a<function>(x))
1688 return true;
1689 for (size_t i=0; i<x.nops(); ++i)
1690 if (has_function(x.op(i)))
1691 return true;
1692
1693 return false;
1694 }
1695
has_symbol_or_function(const ex & x)1696 bool has_symbol_or_function(const ex & x)
1697 {
1698 if (is_exactly_a<symbol>(x) or is_exactly_a<function>(x))
1699 return true;
1700 for (size_t i=0; i<x.nops(); ++i)
1701 if (has_symbol_or_function(x.op(i)))
1702 return true;
1703
1704 return false;
1705 }
1706
has_oneof_function_helper(const ex & x,const std::map<unsigned,int> & m)1707 static bool has_oneof_function_helper(const ex& x,
1708 const std::map<unsigned,int>& m)
1709 {
1710 if (is_exactly_a<function>(x)
1711 and m.find(ex_to<function>(x).get_serial()) != m.end())
1712 return true;
1713 for (size_t i=0; i<x.nops(); ++i)
1714 if (has_oneof_function_helper(x.op(i), m))
1715 return true;
1716
1717 return false;
1718 }
1719
has_allof_function_helper(const ex & x,std::map<unsigned,int> & m)1720 static void has_allof_function_helper(const ex& x,
1721 std::map<unsigned,int>& m)
1722 {
1723 if (is_exactly_a<function>(x)) {
1724 unsigned ser = ex_to<function>(x).get_serial();
1725 if (m.find(ser) != m.end())
1726 m[ser] = 1;
1727 }
1728 for (size_t i=0; i<x.nops(); ++i)
1729 has_allof_function_helper(x.op(i), m);
1730 }
1731
has_function(const ex & x,const std::string & s)1732 bool has_function(const ex& x,
1733 const std::string& s)
1734 {
1735 std::map<unsigned,int> m;
1736 unsigned ser = 0;
1737 for (const auto & elem : function::registered_functions()) {
1738 if (s == elem.name)
1739 m[ser] = 0;
1740 ++ser;
1741 }
1742 if (m.empty())
1743 return false;
1744 return has_oneof_function_helper(x, m);
1745 }
1746
has_function(const ex & x,const std::vector<std::string> & v,bool all)1747 bool has_function(const ex& x,
1748 const std::vector<std::string>& v,
1749 bool all)
1750 {
1751 std::map<unsigned,int> m;
1752 for (const auto & s : v) {
1753 unsigned ser = 0;
1754 for (const auto & elem : function::registered_functions()) {
1755 if (s == elem.name)
1756 m[ser] = 0;
1757 ++ser;
1758 }
1759 }
1760 if (m.empty())
1761 return false;
1762 if (all) {
1763 has_allof_function_helper(x, m);
1764 for (const auto & p : m)
1765 // TODO: false negative if >1 func with same name
1766 if (p.second == 0)
1767 return false;
1768 return true;
1769 }
1770 return has_oneof_function_helper(x, m);
1771 }
1772
1773 } // namespace GiNaC
1774
1775