1 /** @file function.cpp
2  *
3  *  Implementation of class of symbolic functions. */
4 
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22 
23 #define register
24 #define PY_SSIZE_T_CLEAN
25 #include <Python.h>
26 #include "py_funcs.h"
27 #include "function.h"
28 #include "operators.h"
29 #include "fderivative.h"
30 #include "ex.h"
31 #include "lst.h"
32 #include "print.h"
33 #include "power.h"
34 #include "relational.h"
35 #include "archive.h"
36 #include "inifcns.h"
37 #include "tostring.h"
38 #include "utils.h"
39 #include "remember.h"
40 #include "symbol.h"
41 #include "cmatcher.h"
42 #include "wildcard.h"
43 #include "expairseq.h"
44 
45 #include <iostream>
46 #include <string>
47 #include <stdexcept>
48 #include <list>
49 #include <limits>
50 #ifdef DO_GINAC_ASSERT
51 #  include <typeinfo>
52 #endif
53 
54 namespace GiNaC {
55 
56 //////////
57 // helper class function_options
58 //////////
59 
function_options()60 function_options::function_options()
61 {
62 	initialize();
63 }
64 
function_options(std::string const & n,std::string const & tn)65 function_options::function_options(std::string const & n, std::string const & tn)
66 {
67 	initialize();
68 	set_name(n, tn);
69 }
70 
function_options(std::string const & n,unsigned np)71 function_options::function_options(std::string const & n, unsigned np)
72 {
73         static std::string empty;
74 	initialize();
75         set_name(n, empty);
76 	nparams = np;
77 }
78 
~function_options()79 function_options::~function_options()
80 {
81 	// nothing to clean up at the moment
82 }
83 
initialize()84 void function_options::initialize()
85 {
86         static std::string s1("unnamed_function"), s2("\\mbox{unnamed}");
87 	set_name(s1, s2);
88 	nparams = 0;
89 	eval_f = real_part_f = imag_part_f = conjugate_f = derivative_f
90             = pynac_eval_f = expl_derivative_f = power_f = series_f
91             = subs_f = nullptr;
92 	evalf_f = nullptr;
93 	evalf_params_first = true;
94 	apply_chain_rule = true;
95 	use_return_type = false;
96 	eval_use_exvector_args = false;
97 	evalf_use_exvector_args = false;
98 	conjugate_use_exvector_args = false;
99 	real_part_use_exvector_args = false;
100 	imag_part_use_exvector_args = false;
101 	derivative_use_exvector_args = false;
102         expl_derivative_use_exvector_args = false;
103 	power_use_exvector_args = false;
104 	series_use_exvector_args = false;
105 	print_use_exvector_args = false;
106 	use_remember = false;
107 	python_func = 0;
108 	functions_with_same_name = 1;
109 	symtree = 0;
110 }
111 
set_name(std::string const & n,std::string const & tn)112 function_options & function_options::set_name(std::string const & n,
113                                               std::string const & tn)
114 {
115         name.assign(n);
116 	if (tn.empty()) {
117 		TeX_name.assign("{\\rm ");
118                 TeX_name += n;
119                 TeX_name.append("}");
120         }
121 	else
122 	        TeX_name.assign(tn);
123 	return *this;
124 }
125 
latex_name(std::string const & tn)126 function_options & function_options::latex_name(std::string const & tn)
127 {
128 	TeX_name = tn;
129 	return *this;
130 }
131 
132 // the following lines have been generated for max. 14 parameters
eval_func(eval_funcp_1 e)133 function_options & function_options::eval_func(eval_funcp_1 e)
134 {
135 	test_and_set_nparams(1);
136 	eval_f = eval_funcp(e);
137         pynac_eval_f = eval_f;
138 	return *this;
139 }
eval_func(eval_funcp_2 e)140 function_options & function_options::eval_func(eval_funcp_2 e)
141 {
142 	test_and_set_nparams(2);
143 	eval_f = eval_funcp(e);
144         pynac_eval_f = eval_f;
145 	return *this;
146 }
eval_func(eval_funcp_3 e)147 function_options & function_options::eval_func(eval_funcp_3 e)
148 {
149 	test_and_set_nparams(3);
150 	eval_f = eval_funcp(e);
151         pynac_eval_f = eval_f;
152 	return *this;
153 }
eval_func(eval_funcp_6 e)154 function_options & function_options::eval_func(eval_funcp_6 e)
155 {
156 	test_and_set_nparams(6);
157 	eval_f = eval_funcp(e);
158         pynac_eval_f = eval_f;
159 	return *this;
160 }
161 
evalf_func(evalf_funcp_1 ef)162 function_options & function_options::evalf_func(evalf_funcp_1 ef)
163 {
164 	test_and_set_nparams(1);
165 	evalf_f = evalf_funcp(ef);
166 	return *this;
167 }
evalf_func(evalf_funcp_2 ef)168 function_options & function_options::evalf_func(evalf_funcp_2 ef)
169 {
170 	test_and_set_nparams(2);
171 	evalf_f = evalf_funcp(ef);
172 	return *this;
173 }
evalf_func(evalf_funcp_3 ef)174 function_options & function_options::evalf_func(evalf_funcp_3 ef)
175 {
176 	test_and_set_nparams(3);
177 	evalf_f = evalf_funcp(ef);
178 	return *this;
179 }
evalf_func(evalf_funcp_6 ef)180 function_options & function_options::evalf_func(evalf_funcp_6 ef)
181 {
182 	test_and_set_nparams(6);
183 	evalf_f = evalf_funcp(ef);
184 	return *this;
185 }
186 
conjugate_func(conjugate_funcp_1 c)187 function_options & function_options::conjugate_func(conjugate_funcp_1 c)
188 {
189 	test_and_set_nparams(1);
190 	conjugate_f = conjugate_funcp(c);
191 	return *this;
192 }
conjugate_func(conjugate_funcp_2 c)193 function_options & function_options::conjugate_func(conjugate_funcp_2 c)
194 {
195 	test_and_set_nparams(2);
196 	conjugate_f = conjugate_funcp(c);
197 	return *this;
198 }
conjugate_func(conjugate_funcp_3 c)199 function_options & function_options::conjugate_func(conjugate_funcp_3 c)
200 {
201 	test_and_set_nparams(3);
202 	conjugate_f = conjugate_funcp(c);
203 	return *this;
204 }
205 
real_part_func(real_part_funcp_1 c)206 function_options & function_options::real_part_func(real_part_funcp_1 c)
207 {
208 	test_and_set_nparams(1);
209 	real_part_f = real_part_funcp(c);
210 	return *this;
211 }
real_part_func(real_part_funcp_2 c)212 function_options & function_options::real_part_func(real_part_funcp_2 c)
213 {
214 	test_and_set_nparams(2);
215 	real_part_f = real_part_funcp(c);
216 	return *this;
217 }
real_part_func(real_part_funcp_3 c)218 function_options & function_options::real_part_func(real_part_funcp_3 c)
219 {
220 	test_and_set_nparams(3);
221 	real_part_f = real_part_funcp(c);
222 	return *this;
223 }
224 
imag_part_func(imag_part_funcp_1 c)225 function_options & function_options::imag_part_func(imag_part_funcp_1 c)
226 {
227 	test_and_set_nparams(1);
228 	imag_part_f = imag_part_funcp(c);
229 	return *this;
230 }
imag_part_func(imag_part_funcp_2 c)231 function_options & function_options::imag_part_func(imag_part_funcp_2 c)
232 {
233 	test_and_set_nparams(2);
234 	imag_part_f = imag_part_funcp(c);
235 	return *this;
236 }
imag_part_func(imag_part_funcp_3 c)237 function_options & function_options::imag_part_func(imag_part_funcp_3 c)
238 {
239 	test_and_set_nparams(3);
240 	imag_part_f = imag_part_funcp(c);
241 	return *this;
242 }
243 
derivative_func(derivative_funcp_1 d)244 function_options & function_options::derivative_func(derivative_funcp_1 d)
245 {
246 	test_and_set_nparams(1);
247 	derivative_f = derivative_funcp(d);
248 	return *this;
249 }
derivative_func(derivative_funcp_2 d)250 function_options & function_options::derivative_func(derivative_funcp_2 d)
251 {
252 	test_and_set_nparams(2);
253 	derivative_f = derivative_funcp(d);
254 	return *this;
255 }
derivative_func(derivative_funcp_3 d)256 function_options & function_options::derivative_func(derivative_funcp_3 d)
257 {
258 	test_and_set_nparams(3);
259 	derivative_f = derivative_funcp(d);
260 	return *this;
261 }
derivative_func(derivative_funcp_6 d)262 function_options & function_options::derivative_func(derivative_funcp_6 d)
263 {
264 	test_and_set_nparams(6);
265 	derivative_f = derivative_funcp(d);
266 	return *this;
267 }
268 
expl_derivative_func(expl_derivative_funcp_1 d)269 function_options & function_options::expl_derivative_func(expl_derivative_funcp_1 d)
270 {
271 	test_and_set_nparams(1);
272 	expl_derivative_f = expl_derivative_funcp(d);
273 	return *this;
274 }
expl_derivative_func(expl_derivative_funcp_2 d)275 function_options & function_options::expl_derivative_func(expl_derivative_funcp_2 d)
276 {
277 	test_and_set_nparams(2);
278 	expl_derivative_f = expl_derivative_funcp(d);
279 	return *this;
280 }
expl_derivative_func(expl_derivative_funcp_3 d)281 function_options & function_options::expl_derivative_func(expl_derivative_funcp_3 d)
282 {
283 	test_and_set_nparams(3);
284 	expl_derivative_f = expl_derivative_funcp(d);
285 	return *this;
286 }
287 
power_func(power_funcp_1 d)288 function_options & function_options::power_func(power_funcp_1 d)
289 {
290 	test_and_set_nparams(1);
291 	power_f = power_funcp(d);
292 	return *this;
293 }
power_func(power_funcp_2 d)294 function_options & function_options::power_func(power_funcp_2 d)
295 {
296 	test_and_set_nparams(2);
297 	power_f = power_funcp(d);
298 	return *this;
299 }
power_func(power_funcp_3 d)300 function_options & function_options::power_func(power_funcp_3 d)
301 {
302 	test_and_set_nparams(3);
303 	power_f = power_funcp(d);
304 	return *this;
305 }
306 
series_func(series_funcp_1 s)307 function_options & function_options::series_func(series_funcp_1 s)
308 {
309 	test_and_set_nparams(1);
310 	series_f = series_funcp(s);
311 	return *this;
312 }
series_func(series_funcp_2 s)313 function_options & function_options::series_func(series_funcp_2 s)
314 {
315 	test_and_set_nparams(2);
316 	series_f = series_funcp(s);
317 	return *this;
318 }
series_func(series_funcp_3 s)319 function_options & function_options::series_func(series_funcp_3 s)
320 {
321 	test_and_set_nparams(3);
322 	series_f = series_funcp(s);
323 	return *this;
324 }
325 
subs_func(subs_funcp_1 s)326 function_options & function_options::subs_func(subs_funcp_1 s)
327 {
328 	test_and_set_nparams(1);
329 	subs_f = subs_funcp(s);
330 	return *this;
331 }
subs_func(subs_funcp_2 s)332 function_options & function_options::subs_func(subs_funcp_2 s)
333 {
334 	test_and_set_nparams(2);
335 	subs_f = subs_funcp(s);
336 	return *this;
337 }
subs_func(subs_funcp_3 s)338 function_options & function_options::subs_func(subs_funcp_3 s)
339 {
340 	test_and_set_nparams(3);
341 	subs_f = subs_funcp(s);
342 	return *this;
343 }
344 
345 // end of generated lines
346 
eval_func(eval_funcp_exvector e)347 function_options& function_options::eval_func(eval_funcp_exvector e)
348 {
349 	eval_use_exvector_args = true;
350 	eval_f = eval_funcp(e);
351         pynac_eval_f = eval_f;
352 	return *this;
353 }
evalf_func(evalf_funcp_exvector ef)354 function_options& function_options::evalf_func(evalf_funcp_exvector ef)
355 {
356 	evalf_use_exvector_args = true;
357 	evalf_f = evalf_funcp(ef);
358 	return *this;
359 }
conjugate_func(conjugate_funcp_exvector c)360 function_options& function_options::conjugate_func(conjugate_funcp_exvector c)
361 {
362 	conjugate_use_exvector_args = true;
363 	conjugate_f = conjugate_funcp(c);
364 	return *this;
365 }
real_part_func(real_part_funcp_exvector c)366 function_options& function_options::real_part_func(real_part_funcp_exvector c)
367 {
368 	real_part_use_exvector_args = true;
369 	real_part_f = real_part_funcp(c);
370 	return *this;
371 }
imag_part_func(imag_part_funcp_exvector c)372 function_options& function_options::imag_part_func(imag_part_funcp_exvector c)
373 {
374 	imag_part_use_exvector_args = true;
375 	imag_part_f = imag_part_funcp(c);
376 	return *this;
377 }
378 
derivative_func(derivative_funcp_exvector d)379 function_options& function_options::derivative_func(derivative_funcp_exvector d)
380 {
381 	derivative_use_exvector_args = true;
382 	derivative_f = derivative_funcp(d);
383 	return *this;
384 }
power_func(power_funcp_exvector d)385 function_options& function_options::power_func(power_funcp_exvector d)
386 {
387 	power_use_exvector_args = true;
388 	power_f = power_funcp(d);
389 	return *this;
390 }
series_func(series_funcp_exvector s)391 function_options& function_options::series_func(series_funcp_exvector s)
392 {
393 	series_use_exvector_args = true;
394 	series_f = series_funcp(s);
395 	return *this;
396 }
397 
derivative_func(derivative_funcp_exvector_symbol d)398 function_options& function_options::derivative_func(
399 		derivative_funcp_exvector_symbol d)
400 {
401 	derivative_use_exvector_args = true;
402 	derivative_f = derivative_funcp(d);
403 	return *this;
404 }
405 
eval_func(PyObject * e)406 function_options& function_options::eval_func(PyObject* e)
407 {
408 	python_func |= eval_python_f;
409 	eval_f = eval_funcp(e);
410 	return *this;
411 }
evalf_func(PyObject * ef)412 function_options& function_options::evalf_func(PyObject* ef)
413 {
414 	python_func |= evalf_python_f;
415 	evalf_f = evalf_funcp(ef);
416 	return *this;
417 }
conjugate_func(PyObject * c)418 function_options& function_options::conjugate_func(PyObject* c)
419 {
420 	python_func |= conjugate_python_f;
421 	conjugate_f = conjugate_funcp(c);
422 	return *this;
423 }
real_part_func(PyObject * c)424 function_options& function_options::real_part_func(PyObject* c)
425 {
426 	python_func |= real_part_python_f;
427 	real_part_f = real_part_funcp(c);
428 	return *this;
429 }
imag_part_func(PyObject * c)430 function_options& function_options::imag_part_func(PyObject* c)
431 {
432 	python_func |= imag_part_python_f;
433 	imag_part_f = imag_part_funcp(c);
434 	return *this;
435 }
436 
derivative_func(PyObject * d)437 function_options& function_options::derivative_func(PyObject* d)
438 {
439 	python_func |= derivative_python_f;
440 	derivative_f = derivative_funcp(d);
441 	return *this;
442 }
power_func(PyObject * d)443 function_options& function_options::power_func(PyObject* d)
444 {
445 	python_func |= power_python_f;
446 	power_f = power_funcp(d);
447 	return *this;
448 }
series_func(PyObject * s)449 function_options& function_options::series_func(PyObject* s)
450 {
451 	python_func |= series_python_f;
452 	series_f = series_funcp(s);
453 	return *this;
454 }
subs_func(PyObject * e)455 function_options& function_options::subs_func(PyObject* e)
456 {
457 	python_func |= subs_python_f;
458 	subs_f = subs_funcp(e);
459 	return *this;
460 }
461 
set_return_type(unsigned rt,tinfo_t rtt)462 function_options & function_options::set_return_type(unsigned rt, tinfo_t rtt)
463 {
464 	use_return_type = true;
465 	return_type = rt;
466 	return_type_tinfo = rtt;
467 	return *this;
468 }
469 
do_not_evalf_params()470 function_options & function_options::do_not_evalf_params()
471 {
472 	evalf_params_first = false;
473 	return *this;
474 }
475 
do_not_apply_chain_rule()476 function_options & function_options::do_not_apply_chain_rule()
477 {
478 	apply_chain_rule = false;
479 	return *this;
480 }
481 
remember(unsigned size,unsigned assoc_size,unsigned strategy)482 function_options & function_options::remember(unsigned size,
483                                               unsigned assoc_size,
484                                               unsigned strategy)
485 {
486 	use_remember = true;
487 	remember_size = size;
488 	remember_assoc_size = assoc_size;
489 	remember_strategy = strategy;
490 	return *this;
491 }
492 
overloaded(unsigned o)493 function_options & function_options::overloaded(unsigned o)
494 {
495 	functions_with_same_name = o;
496 	return *this;
497 }
498 
test_and_set_nparams(unsigned n)499 void function_options::test_and_set_nparams(unsigned n)
500 {
501 	if (nparams==0) {
502 		nparams = n;
503 	} else if (nparams!=n) {
504 		// we do not throw an exception here because this code is
505 		// usually executed before main(), so the exception could not
506 		// be caught anyhow
507 		std::cerr << "WARNING: " << name << "(): number of parameters ("
508 		          << n << ") differs from number set before ("
509 		          << nparams << ")" << std::endl;
510 	}
511 }
512 
set_print_func(unsigned id,print_funcp f)513 void function_options::set_print_func(unsigned id, print_funcp f)
514 {
515 	if (id >= print_dispatch_table.size())
516 		print_dispatch_table.resize(id + 1);
517 	print_dispatch_table[id] = f;
518 }
519 
set_print_latex_func(PyObject * f)520 void function_options::set_print_latex_func(PyObject* f)
521 {
522 	unsigned id = print_latex::get_class_info_static().options.get_id();
523 	if (id >= print_dispatch_table.size())
524 		print_dispatch_table.resize(id + 1);
525 	print_dispatch_table[id] = print_funcp(f);
526 }
527 
set_print_dflt_func(PyObject * f)528 void function_options::set_print_dflt_func(PyObject* f)
529 {
530 	unsigned id = print_dflt::get_class_info_static().options.get_id();
531 	if (id >= print_dispatch_table.size())
532 		print_dispatch_table.resize(id + 1);
533 	print_dispatch_table[id] = print_funcp(f);
534 }
535 
536 /** This can be used as a hook for external applications. */
537 unsigned function::current_serial = 0;
538 
539 
540 registered_class_info function::reg_info = \
541         registered_class_info(registered_class_options("function",
542                                 "exprseq",
543                                 &function::tinfo_static));
544 
545 const tinfo_static_t function::tinfo_static = {};
546 
547 //////////
548 // default constructor
549 //////////
550 
551 // public
552 
function()553 function::function() : serial(0)
554 {
555 	tinfo_key = &function::tinfo_static;
556 }
557 
558 //////////
559 // other constructors
560 //////////
561 
562 // public
563 
function(unsigned ser)564 function::function(unsigned ser) : serial(ser)
565 {
566 	tinfo_key = &function::tinfo_static;
567 }
568 
569 // the following lines have been generated for max. 14 parameters
function(unsigned ser,const ex & param1)570 function::function(unsigned ser, const ex & param1)
571 	: exprseq(param1), serial(ser)
572 {
573 	tinfo_key = &function::tinfo_static;
574 }
function(unsigned ser,const ex & param1,const ex & param2)575 function::function(unsigned ser, const ex & param1, const ex & param2)
576 	: exprseq(param1, param2), serial(ser)
577 {
578 	tinfo_key = &function::tinfo_static;
579 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3)580 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3)
581 	: exprseq(param1, param2, param3), serial(ser)
582 {
583 	tinfo_key = &function::tinfo_static;
584 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3,const ex & param4)585 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3, const ex & param4)
586 	: exprseq(param1, param2, param3, param4), serial(ser)
587 {
588 	tinfo_key = &function::tinfo_static;
589 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3,const ex & param4,const ex & param5)590 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3, const ex & param4, const ex & param5)
591 	: exprseq(param1, param2, param3, param4, param5), serial(ser)
592 {
593 	tinfo_key = &function::tinfo_static;
594 }
function(unsigned ser,const ex & param1,const ex & param2,const ex & param3,const ex & param4,const ex & param5,const ex & param6)595 function::function(unsigned ser, const ex & param1, const ex & param2, const ex & param3, const ex & param4, const ex & param5, const ex & param6)
596 	: exprseq(param1, param2, param3, param4, param5, param6), serial(ser)
597 {
598 	tinfo_key = &function::tinfo_static;
599 }
600 
601 // end of generated lines
602 
function(unsigned ser,exprseq es)603 function::function(unsigned ser, exprseq  es) : exprseq(std::move(es)), serial(ser)
604 {
605 	tinfo_key = &function::tinfo_static;
606 
607 	// Force re-evaluation even if the exprseq was already evaluated
608 	// (the exprseq copy constructor copies the flags)
609 	clearflag(status_flags::evaluated);
610 }
611 
function(unsigned ser,const exvector & v,bool discardable)612 function::function(unsigned ser, const exvector & v, bool discardable)
613   : exprseq(v,discardable), serial(ser)
614 {
615 	tinfo_key = &function::tinfo_static;
616 }
617 
function(unsigned ser,std::unique_ptr<exvector> vp)618 function::function(unsigned ser, std::unique_ptr<exvector> vp)
619   : exprseq(std::move(vp)), serial(ser)
620 {
621 	tinfo_key = &function::tinfo_static;
622 }
623 
624 //////////
625 // archiving
626 //////////
627 
628 /** Construct object from archive_node. */
function(const archive_node & n,lst & sym_lst)629 function::function(const archive_node &n, lst &sym_lst) : inherited(n, sym_lst)
630 {
631 	// get python_func flag
632 	// since python_func used to be a bool flag in the old days,
633 	// in order to unarchive old format archives, we first check if the
634 	// archive node contains a bool property
635 	// if it doesn't we look for the unsigned property indicating which
636 	// custom functions are defined in python
637 	unsigned python_func;
638 	bool old_python_func;
639 	if (n.find_bool("python", old_python_func))
640 		python_func = old_python_func ?  0xFFFF : 0;
641 	else if(!n.find_unsigned("python", python_func))
642 		throw std::runtime_error("function::function archive error: cannot read python_func flag");
643 	std::string s;
644 	if (python_func != 0u) {
645 		// read the pickle from the archive
646 		if (!n.find_string("pickle", s))
647 			throw std::runtime_error("function::function archive error: cannot read pickled function");
648 		// unpickle
649 		PyObject* arg = Py_BuildValue("s#",s.c_str(), s.size());
650 		PyObject* sfunc = py_funcs.py_loads(arg);
651 		Py_DECREF(arg);
652 		if (PyErr_Occurred() != nullptr) {
653 		    throw(std::runtime_error("function::function archive error: caught exception in py_loads"));
654 		}
655 		// get the serial of the new SFunction
656 		unsigned int ser = py_funcs.py_get_serial_from_sfunction(sfunc);
657 		if (PyErr_Occurred() != nullptr) {
658 		    throw(std::runtime_error("function::function archive error: cannot get serial from SFunction"));
659 		}
660 		// set serial
661 		serial = ser;
662 	} else { // otherwise
663 	// Find serial number by function name
664 	if (n.find_string("name", s)) {
665 		unsigned int ser = 0;
666 		unsigned int nargs = seq.size();
667                 for (const auto & elem : registered_functions()) {
668 			if (s == elem.name && nargs == elem.nparams) {
669 				serial = ser;
670 				return;
671 			}
672 			++ser;
673 		}
674 		// if the name is not already in the registry, we are
675 		// unarchiving a SymbolicFunction without any custom methods
676 		// Call Python to create a new symbolic function with name s
677 		// and get the serial of this new SymbolicFunction
678 		ser = py_funcs.py_get_serial_for_new_sfunction(s, nargs);
679 		if (PyErr_Occurred() != nullptr) {
680 		    throw(std::runtime_error("function::function archive error: cannot create new symbolic function " + s));
681 		}
682 		serial = ser;
683 		//throw (std::runtime_error("unknown function '" + s + "' in archive"));
684 	} else
685 		throw (std::runtime_error("unnamed function in archive"));
686 	}
687 }
688 
689 /** Unarchive the object. */
unarchive(const archive_node & n,lst & sym_lst)690 ex function::unarchive(const archive_node &n, lst &sym_lst)
691 {
692 	return (new function(n, sym_lst))->setflag(status_flags::dynallocated);
693 }
694 
695 /** Archive the object. */
archive(archive_node & n) const696 void function::archive(archive_node &n) const
697 {
698 	inherited::archive(n);
699 	GINAC_ASSERT(serial < registered_functions().size());
700 	// we use Python's pickling mechanism to archive symbolic functions
701 	// with customized methods defined in Python. Symbolic functions
702 	// defined from c++ or those without custom methods are archived
703 	// directly, without calling Python. The python_func flag indicates if
704 	// we should use the python unpickling mechanism, or the regular
705 	// unarchiving for c++ functions.
706 	unsigned python_func = registered_functions()[serial].python_func;
707 	if (python_func != 0u) {
708 		n.add_unsigned("python", python_func);
709 		// find the corresponding SFunction object
710 		PyObject* sfunc = py_funcs.py_get_sfunction_from_serial(serial);
711 		if (PyErr_Occurred() != nullptr) {
712 		    throw(std::runtime_error("function::archive cannot get serial from SFunction"));
713 		}
714 		// call python to pickle it
715 		std::string* pickled = py_funcs.py_dumps(sfunc);
716 		if (PyErr_Occurred() != nullptr) {
717 		    throw(std::runtime_error("function::archive py_dumps raised exception"));
718 		}
719 		// store the pickle in the archive
720 		n.add_string("pickle", *pickled);
721 		delete pickled;
722 	} else {
723 		n.add_unsigned("python", 0);
724 		n.add_string("name", registered_functions()[serial].name);
725 	}
726 }
727 
728 //////////
729 // functions overriding virtual functions from base classes
730 //////////
731 
732 // public
733 
print(const print_context & c,unsigned level) const734 void function::print(const print_context & c, unsigned level) const
735 {
736 	GINAC_ASSERT(serial<registered_functions().size());
737 	// Dynamically dispatch on print_context type
738 	const print_context_class_info *pc_info = &c.get_class_info();
739 	if (serial >= static_cast<unsigned>(py_funcs.py_get_ginac_serial())) {
740 		//convert arguments to a PyTuple of Expressions
741 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
742 
743 		std::string* sout;
744 		if (is_a<print_latex>(c)) {
745 			sout = py_funcs.py_latex_function(serial, args);
746                         if (PyErr_Occurred() != nullptr) {
747                                 throw(std::runtime_error("function::print(): python print function raised exception"));
748                         }
749                         c.s << *sout;
750                         c.s.flush();
751 		}
752 		else if (is_a<print_tree>(c)) {
753 			sout = py_funcs.py_print_function(serial, args);
754                         if (PyErr_Occurred() != nullptr) {
755                                 throw(std::runtime_error("function::print(): python print function raised exception"));
756                         }
757                         std::string fname = sout->substr(0, sout->find_first_of('('));
758 			c.s << std::string(level, ' ') << class_name() << " "
759 			    << fname << " @" << this << ", serial=" << serial
760 			    << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
761 			    << ", nops=" << nops()
762 			    << std::endl;
763 			unsigned delta_indent = dynamic_cast<const print_tree &>(c).delta_indent;
764 			for (const auto& term : seq)
765 				term.print(c, level + delta_indent);
766 			c.s << std::string(level + delta_indent, ' ') << "=====" << std::endl;
767 		}
768                 else {
769 			sout = py_funcs.py_print_function(serial, args);
770                         if (PyErr_Occurred() != nullptr) {
771                                 throw(std::runtime_error("function::print(): python print function raised exception"));
772                         }
773                         c.s << *sout;
774                         c.s.flush();
775 		}
776 
777 		delete sout;
778 		Py_DECREF(args);
779 	} else {
780 
781 		if (is_a<print_latex>(c)) {
782 		        PyObject* sfunc = py_funcs.py_get_sfunction_from_serial(serial);
783                         if (PyObject_HasAttrString(sfunc, "_print_latex_")) {
784                                 PyObject* args = py_funcs.exvector_to_PyTuple(seq);
785                                 std::string* sout;
786                                 sout = py_funcs.py_latex_function(serial, args);
787                                 if (PyErr_Occurred() != nullptr) {
788                                         throw(std::runtime_error("function::print(): python print function raised exception"));
789                                 }
790                                 c.s << *sout;
791                                 c.s.flush();
792                                 delete sout;
793                                 Py_DECREF(args);
794                                 return;
795                         }
796 		}
797 
798                 const function_options &opt = registered_functions()[serial];
799 		const std::vector<print_funcp> &pdt = opt.print_dispatch_table;
800 
801 
802 next_context:
803 	unsigned id = pc_info->options.get_id();
804 	if (id >= pdt.size() || pdt[id] == nullptr) {
805 
806 		// Method not found, try parent print_context class
807 		const print_context_class_info *parent_pc_info = pc_info->get_parent();
808 		if (parent_pc_info != nullptr) {
809 			pc_info = parent_pc_info;
810 			goto next_context;
811 		}
812 
813 		// Method still not found, use default output
814 		if (is_a<print_tree>(c)) {
815 
816 			c.s << std::string(level, ' ') << class_name() << " "
817 			    << opt.name << " @" << this << ", serial=" << serial
818 			    << std::hex << ", hash=0x" << hashvalue << ", flags=0x" << flags << std::dec
819 			    << ", nops=" << nops()
820 			    << std::endl;
821 			unsigned delta_indent = dynamic_cast<const print_tree &>(c).delta_indent;
822 			for (auto & elem : seq)
823 				elem.print(c, level + delta_indent);
824 			c.s << std::string(level + delta_indent, ' ') << "=====" << std::endl;
825 
826 		} else if (is_a<print_latex>(c)) {
827 			c.s << opt.TeX_name;
828 			printseq(c, "\\left(", ',', "\\right)", exprseq::precedence(), function::precedence());
829 		} else {
830 			c.s << opt.name;
831 			printseq(c, "(", ',', ")", exprseq::precedence(), function::precedence());
832 		}
833 
834 	} else {
835 
836 		// Method found, call it
837 		current_serial = serial;
838                 if (opt.print_use_exvector_args)
839                         (reinterpret_cast<print_funcp_exvector>(pdt[id]))(seq, c);
840                 else
841                         switch (opt.nparams) {
842                         // the following lines have been generated for max. 14 parameters
843                         case 1:
844                                 (reinterpret_cast<print_funcp_1>(pdt[id]))(seq[1 - 1], c);
845                                 break;
846                         case 2:
847                                 (reinterpret_cast<print_funcp_2>(pdt[id]))(seq[1 - 1], seq[2 - 1], c);
848                                 break;
849                         case 3:
850                                 (reinterpret_cast<print_funcp_3>(pdt[id]))(seq[1 - 1], seq[2 - 1], seq[3 - 1], c);
851                                 break;
852 
853                         // end of generated lines
854                         default:
855                                 throw(std::logic_error("function::print(): invalid nparams"));
856                         }
857         }
858 	}
859 }
860 
expand(unsigned options) const861 ex function::expand(unsigned options) const
862 {
863 	return inherited::expand(options);
864 }
865 
eval(int level) const866 ex function::eval(int level) const
867 {
868 	if (level>1) {
869 		// first evaluate children, then we will end up here again
870 		return function(serial,evalchildren(level));
871 	}
872 
873 	GINAC_ASSERT(serial<registered_functions().size());
874 	const function_options &opt = registered_functions()[serial];
875 
876 	bool use_remember = opt.use_remember;
877 	ex eval_result;
878 	if (use_remember && lookup_remember_table(eval_result)) {
879 		return eval_result;
880 	}
881 	current_serial = serial;
882 
883 	if (opt.eval_f==nullptr)
884 		return this->hold();
885         eval_funcp eval_f;
886         if (opt.pynac_eval_f == nullptr)
887                 eval_f = opt.eval_f;
888         else
889                 eval_f = opt.pynac_eval_f;
890 
891 	if (opt.pynac_eval_f == nullptr
892             and (opt.python_func & function_options::eval_python_f) != 0u) {
893 		// convert seq to a PyTuple of Expressions
894 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
895 		// call opt.eval_f with this list
896 		PyObject* pyresult = PyObject_CallMethod(reinterpret_cast<PyObject*>(eval_f),
897 				const_cast<char*>("_eval_"), const_cast<char*>("O"), args);
898 		Py_DECREF(args);
899 		if (pyresult == nullptr) {
900 			throw(std::runtime_error("function::eval(): python function raised exception"));
901 		}
902 		if ( pyresult == Py_None ) {
903 			return this->hold();
904 		}
905 		// convert output Expression to an ex
906 		eval_result = py_funcs.pyExpression_to_ex(pyresult);
907 		Py_DECREF(pyresult);
908 		if (PyErr_Occurred() != nullptr) {
909 			throw(std::runtime_error("function::eval(): python function (Expression_to_ex) raised exception"));
910 		}
911 	}
912 	else if (opt.eval_use_exvector_args)
913 		eval_result = (reinterpret_cast<eval_funcp_exvector>(eval_f))(seq);
914 	else
915 	switch (opt.nparams) {
916 		// the following lines have been generated for max. 14 parameters
917 	case 1:
918 		eval_result = (reinterpret_cast<eval_funcp_1>(eval_f))(seq[1-1]);
919 		break;
920 	case 2:
921 		eval_result = (reinterpret_cast<eval_funcp_2>(eval_f))(seq[1-1], seq[2-1]);
922 		break;
923 	case 3:
924 		eval_result = (reinterpret_cast<eval_funcp_3>(eval_f))(seq[1-1], seq[2-1], seq[3-1]);
925 		break;
926 	case 6:
927 		eval_result = (reinterpret_cast<eval_funcp_6>(eval_f))(seq[1-1], seq[2-1], seq[3-1], seq[4-1], seq[5-1], seq[6-1]);
928 		break;
929 
930 		// end of generated lines
931 	default:
932 		throw(std::logic_error("function::eval(): invalid nparams"));
933 	}
934 	if (use_remember) {
935 		store_remember_table(eval_result);
936 	}
937 	return eval_result;
938 }
939 
evalf(int level,PyObject * kwds) const940 ex function::evalf(int level, PyObject* kwds) const
941 {
942 	GINAC_ASSERT(serial<registered_functions().size());
943 	const function_options &opt = registered_functions()[serial];
944 
945 	// Evaluate children first?
946 	exvector eseq;
947 	if (level == 1 || !(opt.evalf_params_first))
948 		eseq = seq;
949 	else if (level == -max_recursion_level)
950 		throw(std::runtime_error("max recursion level reached"));
951 	else {
952 		eseq.reserve(seq.size());
953 		--level;
954                 for (const auto & elem : seq)
955 			eseq.push_back(elem.evalf(level, kwds));
956 	}
957 
958 	if (opt.evalf_f == nullptr) {
959                 if (opt.nparams == 1 and is_exactly_a<numeric>(eseq[1-1])) {
960                         const numeric& n = ex_to<numeric>(eseq[1-1]);
961                         try {
962                                 return n.try_py_method(get_name());
963                         }
964                         catch (std::logic_error) {
965                                 try {
966                                         const numeric& nn = ex_to<numeric>(n.evalf()).try_py_method(get_name());
967                                         return nn.to_dict_parent(kwds);
968                                 }
969                                 catch (std::logic_error) {}
970                         }
971                 }
972 		return function(serial,eseq).hold();
973 	}
974 	current_serial = serial;
975 	if ((opt.python_func & function_options::evalf_python_f) != 0u) {
976 		// convert seq to a PyTuple of Expressions
977 		PyObject* args = py_funcs.exvector_to_PyTuple(eseq);
978 		// call opt.evalf_f with this list
979 		PyObject* pyresult = PyEval_CallObjectWithKeywords(
980 			PyObject_GetAttrString(reinterpret_cast<PyObject*>(opt.evalf_f),
981 				"_evalf_"), args, kwds);
982 		Py_DECREF(args);
983 		if (pyresult == nullptr) {
984 			throw(std::runtime_error("function::evalf(): python function raised exception"));
985 		}
986 		// convert output Expression to an ex
987 		ex result = py_funcs.pyExpression_to_ex(pyresult);
988 		Py_DECREF(pyresult);
989 		if (PyErr_Occurred() != nullptr) {
990 			throw(std::runtime_error("function::evalf(): python function (pyExpression_to_ex) raised exception"));
991 		}
992 		return result;
993 	}
994 	if (opt.evalf_use_exvector_args)
995 		return (reinterpret_cast<evalf_funcp_exvector>(opt.evalf_f))(seq, kwds);
996 	switch (opt.nparams) {
997 		// the following lines have been generated for max. 14 parameters
998 	case 1:
999 		return (reinterpret_cast<evalf_funcp_1>(opt.evalf_f))(eseq[1-1], kwds);
1000 	case 2:
1001 		return (reinterpret_cast<evalf_funcp_2>(opt.evalf_f))(eseq[1-1], eseq[2-1], kwds);
1002 	case 3:
1003 		return (reinterpret_cast<evalf_funcp_3>(opt.evalf_f))(eseq[1-1], eseq[2-1], eseq[3-1], kwds);
1004 	case 6:
1005 		return (reinterpret_cast<evalf_funcp_6>(opt.evalf_f))(eseq[1-1], eseq[2-1], eseq[3-1], eseq[4-1], eseq[5-1], eseq[6-1], kwds);
1006 
1007 		// end of generated lines
1008 	}
1009 	throw(std::logic_error("function::evalf(): invalid nparams"));
1010 }
1011 
calchash() const1012 long function::calchash() const
1013 {
1014 	long v = golden_ratio_hash(golden_ratio_hash((intptr_t)tinfo()) ^ serial);
1015 	for (size_t i=0; i<nops(); i++) {
1016 		v = rotate_left(v);
1017 		v ^= this->op(i).gethash();
1018 	}
1019 
1020 	if (is_evaluated()) {
1021 		setflag(status_flags::hash_calculated);
1022 		hashvalue = v;
1023 	}
1024 	return v;
1025 }
1026 
thiscontainer(const exvector & v) const1027 ex function::thiscontainer(const exvector & v) const
1028 {
1029 	return function(serial, v);
1030 }
1031 
thiscontainer(std::unique_ptr<exvector> vp) const1032 ex function::thiscontainer(std::unique_ptr<exvector> vp) const
1033 {
1034 	return function(serial, std::move(vp));
1035 }
1036 
1037 /** Implementation of ex::series for functions.
1038  *  @see ex::series */
series(const relational & r,int order,unsigned options) const1039 ex function::series(const relational & r, int order, unsigned options) const
1040 {
1041 	GINAC_ASSERT(serial<registered_functions().size());
1042 	const function_options &opt = registered_functions()[serial];
1043 
1044 	if (opt.series_f==nullptr) {
1045 		return basic::series(r, order);
1046 	}
1047 	ex res;
1048 	current_serial = serial;
1049 	if ((opt.python_func & function_options::series_python_f) != 0u) {
1050 		// convert seq to a PyTuple of Expressions
1051 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1052 		// create a dictionary {'order': order, 'options':options}
1053 		PyObject* kwds = Py_BuildValue("{s:i,s:I}","order",order,"options",options);
1054 		// add variable to expand for as a keyword argument
1055 		PyDict_SetItemString(kwds, "var", py_funcs.ex_to_pyExpression(r.lhs()));
1056 		// add the point of expansion as a keyword argument
1057 		PyDict_SetItemString(kwds, "at", py_funcs.ex_to_pyExpression(r.rhs()));
1058 		// call opt.series_f with this list
1059 		PyObject* pyresult = PyEval_CallObjectWithKeywords(
1060 			PyObject_GetAttrString(reinterpret_cast<PyObject*>(opt.series_f),
1061 				"_series_"), args, kwds);
1062 		Py_DECREF(args);
1063 		Py_DECREF(kwds);
1064 		if (pyresult == nullptr) {
1065 			throw(std::runtime_error("function::series(): python function raised exception"));
1066 		}
1067 		// convert output Expression to an ex
1068 		ex result = py_funcs.pyExpression_to_ex(pyresult);
1069 		Py_DECREF(pyresult);
1070 		if (PyErr_Occurred() != nullptr) {
1071 			throw(std::runtime_error("function::series(): python function (pyExpression_to_ex) raised exception"));
1072 		}
1073 		return result;
1074 	}
1075 	if (opt.series_use_exvector_args) {
1076 		try {
1077 			res = (reinterpret_cast<series_funcp_exvector>(opt.series_f))(seq, r, order, options);
1078 		} catch (do_taylor) {
1079 			res = basic::series(r, order, options);
1080 		}
1081 		return res;
1082 	}
1083 	switch (opt.nparams) {
1084 		// the following lines have been generated for max. 14 parameters
1085 	case 1:
1086 		try {
1087 			res = (reinterpret_cast<series_funcp_1>(opt.series_f))(seq[1-1],r,order,options);
1088 		} catch (do_taylor) {
1089 			res = basic::series(r, order, options);
1090 		}
1091 		return res;
1092 	case 2:
1093 		try {
1094 			res = (reinterpret_cast<series_funcp_2>(opt.series_f))(seq[1-1], seq[2-1],r,order,options);
1095 		} catch (do_taylor) {
1096 			res = basic::series(r, order, options);
1097 		}
1098 		return res;
1099 	case 3:
1100 		try {
1101 			res = (reinterpret_cast<series_funcp_3>(opt.series_f))(seq[1-1], seq[2-1], seq[3-1],r,order,options);
1102 		} catch (do_taylor) {
1103 			res = basic::series(r, order, options);
1104 		}
1105 		return res;
1106 
1107 		// end of generated lines
1108 	}
1109 	throw(std::logic_error("function::series(): invalid nparams"));
1110 }
1111 
1112 
1113 /** Implementation of ex::subs for functions. */
subs(const exmap & m,unsigned options) const1114 ex function::subs(const exmap & m, unsigned options) const
1115 {
1116 	GINAC_ASSERT(serial<registered_functions().size());
1117 	const function_options & opt = registered_functions()[serial];
1118 
1119 	if ((opt.python_func & function_options::subs_python_f) != 0u) {
1120 		// convert seq to a PyTuple of Expressions
1121 		PyObject* args = py_funcs.subs_args_to_PyTuple(m, options, seq);
1122 		// call opt.subs_f with this list
1123 		PyObject* pyresult = PyObject_CallMethod(
1124 				reinterpret_cast<PyObject*>(opt.subs_f),
1125 				const_cast<char*>("_subs_"), const_cast<char*>("O"), args);
1126 		Py_DECREF(args);
1127 		if (pyresult == nullptr) {
1128 			throw(std::runtime_error("function::subs(): python method (_subs_) raised exception"));
1129 		}
1130 		// convert output Expression to an ex
1131 		ex result = py_funcs.pyExpression_to_ex(pyresult);
1132 		Py_DECREF(pyresult);
1133 		if (PyErr_Occurred() != nullptr) {
1134 			throw(std::runtime_error("function::subs(): python function (pyExpression_to_ex) raised exception"));
1135 		}
1136 		return result;
1137 	}
1138 	if (opt.subs_f==nullptr)
1139         	return exprseq::subs(m, options);
1140 
1141 	switch (opt.nparams) {
1142 	case 1:
1143 		return (reinterpret_cast<subs_funcp_1>(opt.subs_f))(m, seq[1-1]);
1144 	case 2:
1145 		return (reinterpret_cast<subs_funcp_2>(opt.subs_f))(m, seq[1-1], seq[2-1]);
1146 	case 3:
1147 		return (reinterpret_cast<subs_funcp_3>(opt.subs_f))(m, seq[1-1], seq[2-1], seq[3-1]);
1148 
1149 		// end of generated lines
1150 	}
1151 	throw(std::logic_error("function::subs(): invalid nparams"));
1152 }
1153 
1154 /** Implementation of ex::conjugate for functions. */
conjugate() const1155 ex function::conjugate() const
1156 {
1157 	GINAC_ASSERT(serial<registered_functions().size());
1158 	const function_options & opt = registered_functions()[serial];
1159 
1160 	if (opt.conjugate_f==nullptr) {
1161 		return conjugate_function(*this).hold();
1162 	}
1163 
1164 	if ((opt.python_func & function_options::conjugate_python_f) != 0u) {
1165 		// convert seq to a PyTuple of Expressions
1166 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1167 		// call opt.conjugate_f with this list
1168 		PyObject* pyresult = PyObject_CallMethod(
1169 				reinterpret_cast<PyObject*>(opt.conjugate_f),
1170 				const_cast<char*>("_conjugate_"), const_cast<char*>("O"), args);
1171 		Py_DECREF(args);
1172 		if (pyresult == nullptr) {
1173 			throw(std::runtime_error("function::conjugate(): python function raised exception"));
1174 		}
1175 		// convert output Expression to an ex
1176 		ex result = py_funcs.pyExpression_to_ex(pyresult);
1177 		Py_DECREF(pyresult);
1178 		if (PyErr_Occurred() != nullptr) {
1179 			throw(std::runtime_error("function::conjugate(): python function (pyExpression_to_ex) raised exception"));
1180 		}
1181 		return result;
1182 	}
1183 	if (opt.conjugate_use_exvector_args) {
1184 		return (reinterpret_cast<conjugate_funcp_exvector>(opt.conjugate_f))(seq);
1185 	}
1186 
1187 	switch (opt.nparams) {
1188 		// the following lines have been generated for max. 14 parameters
1189 	case 1:
1190 		return (reinterpret_cast<conjugate_funcp_1>(opt.conjugate_f))(seq[1-1]);
1191 	case 2:
1192 		return (reinterpret_cast<conjugate_funcp_2>(opt.conjugate_f))(seq[1-1], seq[2-1]);
1193 	case 3:
1194 		return (reinterpret_cast<conjugate_funcp_3>(opt.conjugate_f))(seq[1-1], seq[2-1], seq[3-1]);
1195 
1196 		// end of generated lines
1197 	}
1198 	throw(std::logic_error("function::conjugate(): invalid nparams"));
1199 }
1200 
1201 /** Implementation of ex::real_part for functions. */
real_part() const1202 ex function::real_part() const
1203 {
1204 	GINAC_ASSERT(serial<registered_functions().size());
1205 	const function_options & opt = registered_functions()[serial];
1206 
1207 	if (opt.real_part_f==nullptr)
1208 		return basic::real_part();
1209 
1210 	if ((opt.python_func & function_options::real_part_python_f) != 0u) {
1211 		// convert seq to a PyTuple of Expressions
1212 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1213 		// call opt.real_part_f with this list
1214 		PyObject* pyresult = PyObject_CallMethod(reinterpret_cast<PyObject*>(opt.real_part_f),
1215 				const_cast<char*>("_real_part_"), const_cast<char*>("O"), args);
1216 		Py_DECREF(args);
1217 		if (pyresult == nullptr) {
1218 			throw(std::runtime_error("function::real_part(): python function raised exception"));
1219 		}
1220 		// convert output Expression to an ex
1221 		ex result = py_funcs.pyExpression_to_ex(pyresult);
1222 		Py_DECREF(pyresult);
1223 		if (PyErr_Occurred() != nullptr) {
1224 			throw(std::runtime_error("function::real_part(): python function (pyExpression_to_ex) raised exception"));
1225 		}
1226 		return result;
1227 	}
1228 	if (opt.real_part_use_exvector_args)
1229 		return (reinterpret_cast<real_part_funcp_exvector>(opt.real_part_f))(seq);
1230 
1231 	switch (opt.nparams) {
1232 		// the following lines have been generated for max. 14 parameters
1233 	case 1:
1234 		return (reinterpret_cast<real_part_funcp_1>(opt.real_part_f))(seq[1-1]);
1235 	case 2:
1236 		return (reinterpret_cast<real_part_funcp_2>(opt.real_part_f))(seq[1-1], seq[2-1]);
1237 	case 3:
1238 		return (reinterpret_cast<real_part_funcp_3>(opt.real_part_f))(seq[1-1], seq[2-1], seq[3-1]);
1239 
1240 		// end of generated lines
1241 	}
1242 	throw(std::logic_error("function::real_part(): invalid nparams"));
1243 }
1244 
1245 /** Implementation of ex::imag_part for functions. */
imag_part() const1246 ex function::imag_part() const
1247 {
1248 	GINAC_ASSERT(serial<registered_functions().size());
1249 	const function_options & opt = registered_functions()[serial];
1250 
1251 	if (opt.imag_part_f==nullptr)
1252 		return basic::imag_part();
1253 
1254 	if ((opt.python_func & function_options::imag_part_python_f) != 0u) {
1255 		// convert seq to a PyTuple of Expressions
1256 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1257 		// call opt.imag_part_f with this list
1258 		PyObject* pyresult = PyObject_CallMethod(reinterpret_cast<PyObject*>(opt.imag_part_f),
1259 				const_cast<char*>("_imag_part_"), const_cast<char*>("O"), args);
1260 		Py_DECREF(args);
1261 		if (pyresult == nullptr) {
1262 			throw(std::runtime_error("function::imag_part(): python function raised exception"));
1263 		}
1264 		// convert output Expression to an ex
1265 		ex result = py_funcs.pyExpression_to_ex(pyresult);
1266 		Py_DECREF(pyresult);
1267 		if (PyErr_Occurred() != nullptr) {
1268 			throw(std::runtime_error("function::imag_part(): python function (pyExpression_to_ex) raised exception"));
1269 		}
1270 		return result;
1271 	}
1272 	if (opt.imag_part_use_exvector_args)
1273 		return (reinterpret_cast<imag_part_funcp_exvector>(opt.imag_part_f))(seq);
1274 
1275 	switch (opt.nparams) {
1276 		// the following lines have been generated for max. 14 parameters
1277 	case 1:
1278 		return (reinterpret_cast<imag_part_funcp_1>(opt.imag_part_f))(seq[1-1]);
1279 	case 2:
1280 		return (reinterpret_cast<imag_part_funcp_2>(opt.imag_part_f))(seq[1-1], seq[2-1]);
1281 	case 3:
1282 		return (reinterpret_cast<imag_part_funcp_3>(opt.imag_part_f))(seq[1-1], seq[2-1], seq[3-1]);
1283 
1284 		// end of generated lines
1285 	}
1286 	throw(std::logic_error("function::imag_part(): invalid nparams"));
1287 }
1288 
1289 // protected
1290 
1291 /** Implementation of ex::diff() for functions. It applies the chain rule,
1292  *  except for the Order term function.
1293  *  @see ex::diff */
derivative(const symbol & s) const1294 ex function::derivative(const symbol & s) const
1295 {
1296         ex result;
1297 
1298 	/*
1299 	if (serial == Order_SERIAL::serial) {
1300 		// Order Term function only differentiates the argument
1301 		return Order(seq[0].diff(s));
1302 		*/
1303 	GINAC_ASSERT(serial<registered_functions().size());
1304 	const function_options &opt = registered_functions()[serial];
1305 
1306         try {
1307                 // Explicit derivation
1308                 return expl_derivative(s);
1309         } catch (...) {}
1310 
1311 	// Check if we need to apply chain rule
1312 	if (!opt.apply_chain_rule) {
1313 		if (opt.derivative_f == nullptr)
1314 			throw(std::runtime_error("function::derivative(): custom derivative function must be defined"));
1315 
1316 		if ((opt.python_func & function_options::derivative_python_f) != 0u) {
1317 			// convert seq to a PyTuple of Expressions
1318 			PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1319 			// create a dictionary {'diff_param': s}
1320 			PyObject* symb = py_funcs.ex_to_pyExpression(s);
1321 			PyObject* kwds = Py_BuildValue("{s:O}","diff_param",
1322 					symb);
1323 			// call opt.derivative_f with this list
1324 			PyObject* pyresult = PyEval_CallObjectWithKeywords(
1325 				PyObject_GetAttrString(
1326 					reinterpret_cast<PyObject*>(opt.derivative_f),
1327 					"_tderivative_"), args, kwds);
1328 			Py_DECREF(symb);
1329 			Py_DECREF(args);
1330 			Py_DECREF(kwds);
1331 			if (pyresult == nullptr) {
1332 				throw(std::runtime_error("function::derivative(): python function raised exception"));
1333 			}
1334 			// convert output Expression to an ex
1335 			result = py_funcs.pyExpression_to_ex(pyresult);
1336 			Py_DECREF(pyresult);
1337 			if (PyErr_Occurred() != nullptr) {
1338 				throw(std::runtime_error("function::derivative(): python function (pyExpression_to_ex) raised exception"));
1339 			}
1340 			return result;
1341 		}
1342 		// C++ function
1343 		if (!opt.derivative_use_exvector_args)
1344 			throw(std::runtime_error("function::derivative(): cannot call C++ function without exvector args"));
1345 
1346 		return (reinterpret_cast<derivative_funcp_exvector_symbol>(opt.derivative_f))(seq, s);
1347 
1348 	}
1349         // Chain rule
1350         ex arg_diff;
1351         size_t num = seq.size();
1352         for (size_t i=0; i<num; i++) {
1353                 arg_diff = seq[i].diff(s);
1354                 // We apply the chain rule only when it makes sense.  This is not
1355                 // just for performance reasons but also to allow functions to
1356                 // throw when differentiated with respect to one of its arguments
1357                 // without running into trouble with our automatic full
1358                 // differentiation:
1359                 if (!arg_diff.is_zero())
1360                         result += pderivative(i)*arg_diff;
1361         }
1362 
1363 	return result;
1364 }
1365 
compare_same_type(const basic & other) const1366 int function::compare_same_type(const basic & other) const
1367 {
1368 	GINAC_ASSERT(is_a<function>(other));
1369 	const function & o = static_cast<const function &>(other);
1370 
1371 	if (serial != o.serial)
1372 		return serial < o.serial ? -1 : 1;
1373 
1374         return exprseq::compare_same_type(o);
1375 }
1376 
1377 
is_equal_same_type(const basic & other) const1378 bool function::is_equal_same_type(const basic & other) const
1379 {
1380 	GINAC_ASSERT(is_a<function>(other));
1381 	const function & o = static_cast<const function &>(other);
1382 
1383 	if (serial != o.serial)
1384 		return false;
1385 
1386 	return exprseq::is_equal_same_type(o);
1387 }
1388 
match_same_type(const basic & other) const1389 bool function::match_same_type(const basic & other) const
1390 {
1391 	GINAC_ASSERT(is_a<function>(other));
1392 	const function & o = static_cast<const function &>(other);
1393 
1394 	return serial == o.serial;
1395 }
1396 
match(const ex & pattern,exmap & map) const1397 bool function::match(const ex & pattern, exmap& map) const
1398 {
1399 	if (is_exactly_a<wildcard>(pattern)) {
1400                 const auto& it = map.find(pattern);
1401                 if (it != map.end())
1402 		        return is_equal(ex_to<basic>(it->second));
1403 		map[pattern] = *this;
1404 		return true;
1405 	}
1406         if (not is_exactly_a<function>(pattern))
1407                 return false;
1408         CMatcher cm(*this, pattern, map);
1409         const opt_exmap& m = cm.get();
1410         if (not m)
1411                 return false;
1412         map = m.value();
1413         return true;
1414 }
1415 
1416 
return_type() const1417 unsigned function::return_type() const
1418 {
1419 	GINAC_ASSERT(serial<registered_functions().size());
1420 	const function_options &opt = registered_functions()[serial];
1421 
1422 	if (opt.use_return_type) {
1423 		// Return type was explicitly specified
1424 		return opt.return_type;
1425 	}
1426         // Default behavior is to use the return type of the first
1427         // argument. Thus, exp() of a matrix behaves like a matrix, etc.
1428         if (seq.empty())
1429                 return return_types::commutative;
1430 
1431         return seq.begin()->return_type();
1432 
1433 }
1434 
return_type_tinfo() const1435 tinfo_t function::return_type_tinfo() const
1436 {
1437 	GINAC_ASSERT(serial<registered_functions().size());
1438 	const function_options &opt = registered_functions()[serial];
1439 
1440 	if (opt.use_return_type) {
1441 		// Return type was explicitly specified
1442 		return opt.return_type_tinfo;
1443 	}
1444 		// Default behavior is to use the return type of the first
1445 		// argument. Thus, exp() of a matrix behaves like a matrix, etc.
1446 		if (seq.empty())
1447 			return this;
1448 
1449 			return seq.begin()->return_type_tinfo();
1450 
1451 }
1452 
1453 //////////
1454 // new virtual functions which can be overridden by derived classes
1455 //////////
1456 
1457 // none
1458 
1459 //////////
1460 // non-virtual functions in this class
1461 //////////
1462 
1463 // protected
1464 
pderivative(unsigned diff_param) const1465 ex function::pderivative(unsigned diff_param) const // partial differentiation
1466 {
1467 	GINAC_ASSERT(serial<registered_functions().size());
1468 	const function_options &opt = registered_functions()[serial];
1469 
1470 	// No derivative defined? Then return abstract derivative object
1471 	if (opt.derivative_f == nullptr)
1472 		return fderivative(serial, diff_param, seq);
1473 
1474 	current_serial = serial;
1475 	if ((opt.python_func & function_options::derivative_python_f) != 0u) {
1476 		// convert seq to a PyTuple of Expressions
1477 		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
1478 		// create a dictionary {'diff_param': diff_param}
1479 		PyObject* kwds = Py_BuildValue("{s:I}","diff_param",diff_param);
1480 		// call opt.derivative_f with this list
1481 		PyObject* pyresult = PyEval_CallObjectWithKeywords(
1482 			PyObject_GetAttrString(reinterpret_cast<PyObject*>(opt.derivative_f),
1483 				"_derivative_"), args, kwds);
1484 		Py_DECREF(args);
1485 		Py_DECREF(kwds);
1486 		if (pyresult == nullptr) {
1487 			throw(std::runtime_error("function::pderivative(): python function raised exception"));
1488 		}
1489 		if ( pyresult == Py_None ) {
1490 			return fderivative(serial, diff_param, seq);
1491 		}
1492 		// convert output Expression to an ex
1493 		ex result = py_funcs.pyExpression_to_ex(pyresult);
1494 		Py_DECREF(pyresult);
1495 		if (PyErr_Occurred() != nullptr) {
1496 			throw(std::runtime_error("function::pderivative(): python function (pyExpression_to_ex) raised exception"));
1497 		}
1498 		return result;
1499 	}
1500 	if (opt.derivative_use_exvector_args)
1501 		return (reinterpret_cast<derivative_funcp_exvector>(opt.derivative_f))(seq, diff_param);
1502 	switch (opt.nparams) {
1503 		// the following lines have been generated for max. 14 parameters
1504 	case 1:
1505 		return (reinterpret_cast<derivative_funcp_1>(opt.derivative_f))(seq[1-1],diff_param);
1506 	case 2:
1507 		return (reinterpret_cast<derivative_funcp_2>(opt.derivative_f))(seq[1-1], seq[2-1],diff_param);
1508 	case 3:
1509 		return (reinterpret_cast<derivative_funcp_3>(opt.derivative_f))(seq[1-1], seq[2-1], seq[3-1],diff_param);
1510 	case 6:
1511 		return (reinterpret_cast<derivative_funcp_6>(opt.derivative_f))(seq[1-1], seq[2-1], seq[3-1], seq[4-1], seq[5-1], seq[6-1], diff_param);
1512 
1513 		// end of generated lines
1514 	}
1515 	throw(std::logic_error("function::pderivative(): no diff function defined"));
1516 }
1517 
expl_derivative(const symbol & s) const1518 ex function::expl_derivative(const symbol & s) const // explicit differentiation
1519 {
1520 	GINAC_ASSERT(serial<registered_functions().size());
1521 	const function_options &opt = registered_functions()[serial];
1522 
1523         if (opt.expl_derivative_f) {
1524 		// Invoke the defined explicit derivative function.
1525 		current_serial = serial;
1526 		if (opt.expl_derivative_use_exvector_args)
1527 			return (reinterpret_cast<expl_derivative_funcp_exvector>(opt.expl_derivative_f))(seq, s);
1528 		switch (opt.nparams) {
1529 			// the following lines have been generated for max. 14 parameters
1530 			case 1:
1531 				return (reinterpret_cast<expl_derivative_funcp_1>(opt.expl_derivative_f))(seq[0], s);
1532 			case 2:
1533 				return (reinterpret_cast<expl_derivative_funcp_2>(opt.expl_derivative_f))(seq[0], seq[1], s);
1534 			case 3:
1535 				return (reinterpret_cast<expl_derivative_funcp_3>(opt.expl_derivative_f))(seq[0], seq[1], seq[2], s);
1536 		}
1537 	}
1538 	// There is no fallback for explicit derivative.
1539 	throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
1540 }
1541 
/** Raise the function to a power.
 *
 *  Without a registered power function, an ordinary power object is built
 *  and flagged as dynamically allocated and already evaluated.  Otherwise
 *  the registered power function is called: a Python `_power_` method
 *  (keyword power_param), or a C++ function taking the argument vector or
 *  1-3 individual arguments, plus the exponent.
 *
 *  @param power_param  the exponent
 *  @throws std::runtime_error if the Python callback raised an exception
 *  @throws std::logic_error if nparams has no dispatch case */
ex function::power(const ex & power_param) const // power of function
{
	GINAC_ASSERT(serial<registered_functions().size());
	const function_options &opt = registered_functions()[serial];

	// No power function defined? Then return a plain power object.
	if (opt.power_f == nullptr)
		return (new GiNaC::power(*this, power_param))->setflag(status_flags::dynallocated |
		                                                       status_flags::evaluated);

	current_serial = serial;
	if ((opt.python_func & function_options::power_python_f) != 0u) {
		// convert seq to a PyTuple of Expressions
		PyObject* args = py_funcs.exvector_to_PyTuple(seq);
		// create a dictionary {'power_param': power_param}.
		// PyDict_SetItemString does not steal the value's reference, so
		// ours is released afterwards.
		PyObject* kwds = PyDict_New();
		PyObject* py_param = py_funcs.ex_to_pyExpression(power_param);
		PyDict_SetItemString(kwds, "power_param", py_param);
		Py_XDECREF(py_param);
		// PyObject_GetAttrString returns a new reference (or nullptr)
		PyObject* method = PyObject_GetAttrString(
			reinterpret_cast<PyObject*>(opt.power_f), "_power_");
		if (method == nullptr) {
			Py_DECREF(args);
			Py_DECREF(kwds);
			throw(std::runtime_error("function::power(): python function raised exception"));
		}
		PyObject* pyresult = PyObject_Call(method, args, kwds);
		Py_DECREF(method);
		Py_DECREF(args);
		Py_DECREF(kwds);
		if (pyresult == nullptr) {
			throw(std::runtime_error("function::power(): python function raised exception"));
		}
		// convert output Expression to an ex
		ex result = py_funcs.pyExpression_to_ex(pyresult);
		Py_DECREF(pyresult);
		if (PyErr_Occurred() != nullptr) {
			throw(std::runtime_error("function::power(): python function (pyExpression_to_ex) raised exception"));
		}
		return result;
	}
	if (opt.power_use_exvector_args)
		return (reinterpret_cast<power_funcp_exvector>(opt.power_f))(seq,  power_param);
	// Only the arities 1, 2 and 3 have a dispatch case.
	switch (opt.nparams) {
	case 1:
		return (reinterpret_cast<power_funcp_1>(opt.power_f))(seq[1-1],power_param);
	case 2:
		return (reinterpret_cast<power_funcp_2>(opt.power_f))(seq[1-1], seq[2-1],power_param);
	case 3:
		return (reinterpret_cast<power_funcp_3>(opt.power_f))(seq[1-1], seq[2-1], seq[3-1],power_param);
	}
	throw(std::logic_error("function::power(): no power function defined"));
}
1591 
registered_functions()1592 std::vector<function_options> & function::registered_functions()
1593 {
1594 	static auto  rf = new std::vector<function_options>;
1595 	return *rf;
1596 }
1597 
lookup_remember_table(ex & result) const1598 bool function::lookup_remember_table(ex & result) const
1599 {
1600 	return remember_table::remember_tables()[this->serial].lookup_entry(*this,result);
1601 }
1602 
store_remember_table(ex const & result) const1603 void function::store_remember_table(ex const & result) const
1604 {
1605 	remember_table::remember_tables()[this->serial].add_entry(*this,result);
1606 }
1607 
1608 // public
1609 
register_new(function_options const & opt)1610 unsigned function::register_new(function_options const & opt)
1611 {
1612 	size_t same_name = 0;
1613 	for (auto & elem : registered_functions()) {
1614 		if (elem.name==opt.name) {
1615 			++same_name;
1616 		}
1617 	}
1618 	if (same_name>=opt.functions_with_same_name) {
1619 		// we do not throw an exception here because this code is
1620 		// usually executed before main(), so the exception could not
1621 		// caught anyhow
1622 		//
1623 		// SAGE note:
1624 		// We suppress this warning since we allow a user to create
1625 		// functions with same name, but different parameters
1626 		// Sage SFunction class checks existence of a function before
1627 		// allocating a new one.
1628 		//
1629 		//std::cerr << "WARNING: function name " << opt.name
1630 		//          << " already in use!" << std::endl;
1631 	}
1632 	registered_functions().push_back(opt);
1633 	if (opt.use_remember) {
1634 		remember_table::remember_tables().
1635 			emplace_back(opt.remember_size,
1636 			                         opt.remember_assoc_size,
1637 			                         opt.remember_strategy);
1638 	} else {
1639 		remember_table::remember_tables().emplace_back();
1640 	}
1641 	return registered_functions().size()-1;
1642 }
1643 
1644 /** Find serial number of function by name and number of parameters.
1645  *  Throws exception if function was not found. */
find_function(const std::string & name,unsigned nparams)1646 unsigned function::find_function(const std::string &name, unsigned nparams)
1647 {
1648 	unsigned serial = 0;
1649         for (const auto & elem : registered_functions()) {
1650 		if (elem.get_name() == name && elem.get_nparams() == nparams)
1651 			return serial;
1652 		++serial;
1653 	}
1654 	throw (std::runtime_error("no function '" + name + "' with " + ToString(nparams) + " parameters defined"));
1655 }
1656 
1657 /** Return the print name of the function. */
get_name() const1658 std::string function::get_name() const
1659 {
1660 	GINAC_ASSERT(serial<registered_functions().size());
1661 	return registered_functions()[serial].name;
1662 }
1663 
set_domain(unsigned d)1664 void function::set_domain(unsigned d)
1665 {
1666         domain = d;
1667         iflags.clear();
1668         switch (d) {
1669         case domain::complex:
1670                 break;
1671         case domain::real:
1672                 iflags.set(info_flags::real, true);
1673                 break;
1674         case domain::positive:
1675                 iflags.set(info_flags::real, true);
1676                 iflags.set(info_flags::positive, true);
1677                 break;
1678         case domain::integer:
1679                 iflags.set(info_flags::real, true);
1680                 iflags.set(info_flags::integer, true);
1681                 break;
1682         }
1683 }
1684 
has_function(const ex & x)1685 bool has_function(const ex & x)
1686 {
1687 	if (is_exactly_a<function>(x))
1688 		return true;
1689 	for (size_t i=0; i<x.nops(); ++i)
1690 		if (has_function(x.op(i)))
1691 			return true;
1692 
1693 	return false;
1694 }
1695 
has_symbol_or_function(const ex & x)1696 bool has_symbol_or_function(const ex & x)
1697 {
1698 	if (is_exactly_a<symbol>(x) or is_exactly_a<function>(x))
1699 		return true;
1700 	for (size_t i=0; i<x.nops(); ++i)
1701 		if (has_symbol_or_function(x.op(i)))
1702 			return true;
1703 
1704 	return false;
1705 }
1706 
has_oneof_function_helper(const ex & x,const std::map<unsigned,int> & m)1707 static bool has_oneof_function_helper(const ex& x,
1708                 const std::map<unsigned,int>& m)
1709 {
1710 	if (is_exactly_a<function>(x)
1711             and m.find(ex_to<function>(x).get_serial()) != m.end())
1712 		return true;
1713 	for (size_t i=0; i<x.nops(); ++i)
1714 		if (has_oneof_function_helper(x.op(i), m))
1715 			return true;
1716 
1717 	return false;
1718 }
1719 
has_allof_function_helper(const ex & x,std::map<unsigned,int> & m)1720 static void has_allof_function_helper(const ex& x,
1721                 std::map<unsigned,int>& m)
1722 {
1723 	if (is_exactly_a<function>(x)) {
1724                 unsigned ser = ex_to<function>(x).get_serial();
1725                 if (m.find(ser) != m.end())
1726         		m[ser] = 1;
1727         }
1728 	for (size_t i=0; i<x.nops(); ++i)
1729 		has_allof_function_helper(x.op(i), m);
1730 }
1731 
has_function(const ex & x,const std::string & s)1732 bool has_function(const ex& x,
1733                 const std::string& s)
1734 {
1735         std::map<unsigned,int> m;
1736         unsigned ser = 0;
1737         for (const auto & elem : function::registered_functions()) {
1738                 if (s == elem.name)
1739                         m[ser] = 0;
1740                 ++ser;
1741         }
1742         if (m.empty())
1743                 return false;
1744         return has_oneof_function_helper(x, m);
1745 }
1746 
has_function(const ex & x,const std::vector<std::string> & v,bool all)1747 bool has_function(const ex& x,
1748                 const std::vector<std::string>& v,
1749                 bool all)
1750 {
1751         std::map<unsigned,int> m;
1752         for (const auto & s : v) {
1753                 unsigned ser = 0;
1754                 for (const auto & elem : function::registered_functions()) {
1755                         if (s == elem.name)
1756                                 m[ser] = 0;
1757                         ++ser;
1758                 }
1759         }
1760         if (m.empty())
1761                 return false;
1762         if (all) {
1763                 has_allof_function_helper(x, m);
1764                 for (const auto & p : m)
1765                         // TODO: false negative if >1 func with same name
1766                         if (p.second == 0)
1767                                 return false;
1768                 return true;
1769         }
1770         return has_oneof_function_helper(x, m);
1771 }
1772 
1773 } // namespace GiNaC
1774 
1775