1 /** @file operators.cpp
2  *
3  *  Implementation of GiNaC's overloaded operators. */
4 
5 /*
6  *  GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
21  */
22 
23 #include "numeric.h"
24 #include "add.h"
25 #include "mul.h"
26 #include "power.h"
27 #include "relational.h"
28 #include "print.h"
29 #include "utils.h"
30 
31 #include "operators.h"
32 
33 #include <iostream>
34 #include <iomanip>
35 
36 namespace GiNaC {
37 
38 /** Used internally by operator+() to add two ex objects together. */
exadd(const ex & lh,const ex & rh)39 static inline const ex exadd(const ex & lh, const ex & rh)
40 {
41         if (is_exactly_a<numeric>(lh)
42             and is_exactly_a<numeric>(rh))
43                 return ex_to<numeric>(lh).add(ex_to<numeric>(rh));
44 	return (new add(lh,rh))->setflag(status_flags::dynallocated);
45 }
46 
47 /** Used internally by operator*() to multiply two ex objects together. */
exmul(const ex & lh,const ex & rh)48 static inline const ex exmul(const ex & lh, const ex & rh)
49 {
50         if (is_exactly_a<numeric>(lh)
51             and is_exactly_a<numeric>(rh))
52                 return ex_to<numeric>(lh).mul(ex_to<numeric>(rh));
53         return (new mul(lh,rh))->setflag(status_flags::dynallocated);
54 }
55 
56 /** Used internally by operator-() and friends to change the sign of an argument. */
exminus(const ex & lh)57 static inline const ex exminus(const ex & lh)
58 {
59         if (is_exactly_a<numeric>(lh))
60                 return ex_to<numeric>(lh).negative();
61 	return (new mul(lh,_ex_1))->setflag(status_flags::dynallocated);
62 }
63 
64 // binary arithmetic operators ex with ex
65 
/** Sum of two expressions. */
const ex operator+(const ex & lh, const ex & rh)
{
	return exadd(lh, rh);
}

/** Difference of two expressions, built as lh + (-rh). */
const ex operator-(const ex & lh, const ex & rh)
{
	const ex negated_rh = exminus(rh);
	return exadd(lh, negated_rh);
}

/** Product of two expressions. */
const ex operator*(const ex & lh, const ex & rh)
{
	return exmul(lh, rh);
}

/** Quotient of two expressions, built as lh * rh^(-1). */
const ex operator/(const ex & lh, const ex & rh)
{
	const ex reciprocal = power(rh, _ex_1);
	return exmul(lh, reciprocal);
}
85 
86 
87 // binary arithmetic operators numeric with numeric
88 
/** Sum of two numerics (forwards to numeric::add). */
const numeric operator+(const numeric & lh, const numeric & rh)
{
	const numeric sum = lh.add(rh);
	return sum;
}

/** Difference of two numerics (forwards to numeric::sub). */
const numeric operator-(const numeric & lh, const numeric & rh)
{
	const numeric difference = lh.sub(rh);
	return difference;
}

/** Product of two numerics (forwards to numeric::mul). */
const numeric operator*(const numeric & lh, const numeric & rh)
{
	const numeric product = lh.mul(rh);
	return product;
}

/** Quotient of two numerics (forwards to numeric::div). */
const numeric operator/(const numeric & lh, const numeric & rh)
{
	const numeric quotient = lh.div(rh);
	return quotient;
}
108 
109 
110 // binary arithmetic assignment operators with ex
111 
operator +=(ex & lh,const ex & rh)112 ex & operator+=(ex & lh, const ex & rh)
113 {
114 	return lh = exadd(lh, rh);
115 }
116 
operator -=(ex & lh,const ex & rh)117 ex & operator-=(ex & lh, const ex & rh)
118 {
119 	return lh = exadd(lh, exminus(rh));
120 }
121 
operator *=(ex & lh,const ex & rh)122 ex & operator*=(ex & lh, const ex & rh)
123 {
124 	return lh = exmul(lh, rh);
125 }
126 
operator /=(ex & lh,const ex & rh)127 ex & operator/=(ex & lh, const ex & rh)
128 {
129 	return lh = exmul(lh, power(rh,_ex_1));
130 }
131 
132 
133 // unary operators
134 
operator +(const ex & lh)135 const ex operator+(const ex & lh)
136 {
137 	return lh;
138 }
139 
operator -(const ex & lh)140 const ex operator-(const ex & lh)
141 {
142 	return exminus(lh);
143 }
144 
operator +(const numeric & lh)145 const numeric operator+(const numeric & lh)
146 {
147 	return lh;
148 }
149 
operator -(const numeric & lh)150 const numeric operator-(const numeric & lh)
151 {
152 	return lh.negative(); // better than _num_1_p->mul(lh)
153 }
154 
155 
156 // increment / decrement operators
157 
158 /** Expression prefix increment.  Adds 1 and returns incremented ex. */
operator ++(ex & rh)159 ex & operator++(ex & rh)
160 {
161         if (is_exactly_a<numeric>(rh)) {
162                 rh = numeric(ex_to<numeric>(rh) + *_num1_p);
163                 return rh;
164         }
165 	return rh = exadd(rh, _ex1);
166 }
167 
168 /** Expression prefix decrement.  Subtracts 1 and returns decremented ex. */
operator --(ex & rh)169 ex & operator--(ex & rh)
170 {
171         if (is_exactly_a<numeric>(rh)) {
172                 rh = numeric(ex_to<numeric>(rh) + *_num_1_p);
173                 return rh;
174         }
175 	return rh = exadd(rh, _ex_1);
176 }
177 
178 /** Expression postfix increment.  Returns the ex and leaves the original
179  *  incremented by 1. */
operator ++(ex & lh,int)180 const ex operator++(ex & lh, int)
181 {
182 	ex tmp(lh);
183 	lh = exadd(lh, _ex1);
184 	return tmp;
185 }
186 
187 /** Expression postfix decrement.  Returns the ex and leaves the original
188  *  decremented by 1. */
operator --(ex & lh,int)189 const ex operator--(ex & lh, int)
190 {
191 	ex tmp(lh);
192 	lh = exadd(lh, _ex_1);
193 	return tmp;
194 }
195 
196 /** Numeric prefix increment.  Adds 1 and returns incremented number. */
operator ++(numeric & rh)197 numeric& operator++(numeric & rh)
198 {
199 	rh = rh.add(*_num1_p);
200 	return rh;
201 }
202 
203 /** Numeric prefix decrement.  Subtracts 1 and returns decremented number. */
operator --(numeric & rh)204 numeric& operator--(numeric & rh)
205 {
206 	rh = rh.add(*_num_1_p);
207 	return rh;
208 }
209 
210 /** Numeric postfix increment.  Returns the number and leaves the original
211  *  incremented by 1. */
operator ++(numeric & lh,int)212 const numeric operator++(numeric & lh, int)
213 {
214 	numeric tmp(lh);
215 	lh = lh.add(*_num1_p);
216 	return tmp;
217 }
218 
219 /** Numeric postfix decrement.  Returns the number and leaves the original
220  *  decremented by 1. */
operator --(numeric & lh,int)221 const numeric operator--(numeric & lh, int)
222 {
223 	numeric tmp(lh);
224 	lh = lh.add(*_num_1_p);
225 	return tmp;
226 }
227 
228 // binary relational operators ex with ex
229 
operator ==(const ex & lh,const ex & rh)230 const relational operator==(const ex & lh, const ex & rh)
231 {
232 	return relational(lh,rh,relational::equal);
233 }
234 
operator !=(const ex & lh,const ex & rh)235 const relational operator!=(const ex & lh, const ex & rh)
236 {
237 	return relational(lh,rh,relational::not_equal);
238 }
239 
operator <(const ex & lh,const ex & rh)240 const relational operator<(const ex & lh, const ex & rh)
241 {
242 	return relational(lh,rh,relational::less);
243 }
244 
operator <=(const ex & lh,const ex & rh)245 const relational operator<=(const ex & lh, const ex & rh)
246 {
247 	return relational(lh,rh,relational::less_or_equal);
248 }
249 
operator >(const ex & lh,const ex & rh)250 const relational operator>(const ex & lh, const ex & rh)
251 {
252 	return relational(lh,rh,relational::greater);
253 }
254 
operator >=(const ex & lh,const ex & rh)255 const relational operator>=(const ex & lh, const ex & rh)
256 {
257 	return relational(lh,rh,relational::greater_or_equal);
258 }
259 
260 // input/output stream operators and manipulators
261 
// Index of the iword/pword slot GiNaC uses on every stream.  The slot is
// allocated with xalloc() once, on first use, and the same index is
// returned on every later call.
static int my_ios_index()
{
	static const int slot = std::ios_base::xalloc();
	return slot;
}
267 
268 // Stream format gets copied or destroyed
my_ios_callback(std::ios_base::event ev,std::ios_base & s,int i)269 static void my_ios_callback(std::ios_base::event ev, std::ios_base & s, int i)
270 {
271 	std::unique_ptr<print_context> p(static_cast<print_context *>(s.pword(i)));
272 	if (ev == std::ios_base::erase_event) {
273 		s.pword(i) = nullptr;
274 	} else if (ev == std::ios_base::copyfmt_event && p != nullptr)
275 		s.pword(i) = p->duplicate();
276 }
277 
// Flag bits kept in the stream's iword(my_ios_index()) slot.
enum {
	callback_registered = 0x1 // my_ios_callback has been registered on this stream
};
281 
282 // Get print_context associated with stream, may return 0 if no context has
283 // been associated yet
get_print_context(std::ios_base & s)284 static inline print_context *get_print_context(std::ios_base & s)
285 {
286 	return static_cast<print_context *>(s.pword(my_ios_index()));
287 }
288 
289 // Set print_context associated with stream, retain options
set_print_context(std::ios_base & s,const print_context & c)290 static void set_print_context(std::ios_base & s, const print_context & c)
291 {
292 	int i = my_ios_index();
293 	long flags = s.iword(i);
294 	if ((flags & callback_registered) == 0) {
295 		s.register_callback(my_ios_callback, i);
296 		s.iword(i) = flags | callback_registered;
297 	}
298 	print_context *p = static_cast<print_context *>(s.pword(i));
299 	unsigned options = p != nullptr ? p->options : c.options;
300 	delete p;
301 	p = c.duplicate();
302 	p->options = options;
303 	s.pword(i) = p;
304 }
305 
306 // Get options for print_context associated with stream
get_print_options(std::ios_base & s)307 static inline unsigned get_print_options(std::ios_base & s)
308 {
309 	print_context *p = get_print_context(s);
310 	return p != nullptr ? p->options : 0;
311 }
312 
313 // Set options for print_context associated with stream
set_print_options(std::ostream & s,unsigned options)314 static void set_print_options(std::ostream & s, unsigned options)
315 {
316 	print_context *p = get_print_context(s);
317 	if (p == nullptr)
318 		set_print_context(s, print_dflt(s, options));
319 	else
320 		p->options = options;
321 }
322 
operator <<(std::ostream & os,const ex & e)323 std::ostream & operator<<(std::ostream & os, const ex & e)
324 {
325 	print_context *p = get_print_context(os);
326 	if (p == nullptr)
327 		e.print(print_dflt(os));
328 	else
329 		e.print(*p);
330 	return os;
331 }
332 
operator <<(std::ostream & os,const exvector & e)333 std::ostream & operator<<(std::ostream & os, const exvector & e)
334 {
335 	print_context *p = get_print_context(os);
336 	auto i = e.begin();
337 	auto vend = e.end();
338 
339 	if (i==vend) {
340 		os << "[]";
341 		return os;
342 	}
343 
344 	os << "[";
345 	while (true) {
346 		if (p == nullptr)
347 			i -> print(print_dflt(os));
348 		else
349 			i -> print(*p);
350 		++i;
351 		if (i==vend)
352 			break;
353 		os << ",";
354 	}
355 	os << "]";
356 
357 	return os;
358 }
359 
operator <<(std::ostream & os,const exset & e)360 std::ostream & operator<<(std::ostream & os, const exset & e)
361 {
362 	print_context *p = get_print_context(os);
363 	auto i = e.begin();
364 	auto send = e.end();
365 
366 	if (i==send) {
367 		os << "<>";
368 		return os;
369 	}
370 
371 	os << "<";
372 	while (true) {
373 		if (p == nullptr)
374 			i->print(print_dflt(os));
375 		else
376 			i->print(*p);
377 		++i;
378 		if (i == send)
379 			break;
380 		os << ",";
381 	}
382 	os << ">";
383 
384 	return os;
385 }
386 
operator <<(std::ostream & os,const exmap & e)387 std::ostream & operator<<(std::ostream & os, const exmap & e)
388 {
389 	print_context *p = get_print_context(os);
390 	auto i = e.begin();
391 	auto mend = e.end();
392 
393 	if (i==mend) {
394 		os << "{}";
395 		return os;
396 	}
397 
398 	os << "{";
399 	while (true) {
400 		if (p == nullptr)
401 			i->first.print(print_dflt(os));
402 		else
403 			i->first.print(*p);
404 		os << "==";
405 		if (p == nullptr)
406 			i->second.print(print_dflt(os));
407 		else
408 			i->second.print(*p);
409 		++i;
410 		if( i==mend )
411 			break;
412 		os << ",";
413 	}
414 	os << "}";
415 
416 	return os;
417 }
418 
/** Reading expressions from an input stream is not supported.
 *  @throws std::logic_error on every call */
std::istream & operator>>(std::istream & is, ex & e)
{
	(void)is;
	(void)e;
	throw std::logic_error("expression input from streams not implemented");
}
423 
dflt(std::ostream & os)424 std::ostream & dflt(std::ostream & os)
425 {
426 	set_print_context(os, print_dflt(os));
427 	set_print_options(os, 0);
428 	return os;
429 }
430 
latex(std::ostream & os)431 std::ostream & latex(std::ostream & os)
432 {
433 	set_print_context(os, print_latex(os));
434 	return os;
435 }
436 
python(std::ostream & os)437 std::ostream & python(std::ostream & os)
438 {
439 	set_print_context(os, print_python(os));
440 	return os;
441 }
442 
python_repr(std::ostream & os)443 std::ostream & python_repr(std::ostream & os)
444 {
445 	set_print_context(os, print_python_repr(os));
446 	return os;
447 }
448 
tree(std::ostream & os)449 std::ostream & tree(std::ostream & os)
450 {
451 	set_print_context(os, print_tree(os));
452 	return os;
453 }
454 
index_dimensions(std::ostream & os)455 std::ostream & index_dimensions(std::ostream & os)
456 {
457 	set_print_options(os, get_print_options(os) | print_options::print_index_dimensions);
458 	return os;
459 }
460 
no_index_dimensions(std::ostream & os)461 std::ostream & no_index_dimensions(std::ostream & os)
462 {
463 	set_print_options(os, get_print_options(os) & ~print_options::print_index_dimensions);
464 	return os;
465 }
466 
467 } // namespace GiNaC
468