1 /** @file operators.cpp
2 *
3 * Implementation of GiNaC's overloaded operators. */
4
5 /*
6 * GiNaC Copyright (C) 1999-2008 Johannes Gutenberg University Mainz, Germany
7 *
8 * This program is free software; you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation; either version 2 of the License, or
11 * (at your option) any later version.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21 */
22
23 #include "numeric.h"
24 #include "add.h"
25 #include "mul.h"
26 #include "power.h"
27 #include "relational.h"
28 #include "print.h"
29 #include "utils.h"
30
31 #include "operators.h"
32
33 #include <iostream>
34 #include <iomanip>
35
36 namespace GiNaC {
37
38 /** Used internally by operator+() to add two ex objects together. */
exadd(const ex & lh,const ex & rh)39 static inline const ex exadd(const ex & lh, const ex & rh)
40 {
41 if (is_exactly_a<numeric>(lh)
42 and is_exactly_a<numeric>(rh))
43 return ex_to<numeric>(lh).add(ex_to<numeric>(rh));
44 return (new add(lh,rh))->setflag(status_flags::dynallocated);
45 }
46
47 /** Used internally by operator*() to multiply two ex objects together. */
exmul(const ex & lh,const ex & rh)48 static inline const ex exmul(const ex & lh, const ex & rh)
49 {
50 if (is_exactly_a<numeric>(lh)
51 and is_exactly_a<numeric>(rh))
52 return ex_to<numeric>(lh).mul(ex_to<numeric>(rh));
53 return (new mul(lh,rh))->setflag(status_flags::dynallocated);
54 }
55
56 /** Used internally by operator-() and friends to change the sign of an argument. */
exminus(const ex & lh)57 static inline const ex exminus(const ex & lh)
58 {
59 if (is_exactly_a<numeric>(lh))
60 return ex_to<numeric>(lh).negative();
61 return (new mul(lh,_ex_1))->setflag(status_flags::dynallocated);
62 }
63
64 // binary arithmetic operators ex with ex
65
operator +(const ex & lh,const ex & rh)66 const ex operator+(const ex & lh, const ex & rh)
67 {
68 return exadd(lh, rh);
69 }
70
operator -(const ex & lh,const ex & rh)71 const ex operator-(const ex & lh, const ex & rh)
72 {
73 return exadd(lh, exminus(rh));
74 }
75
operator *(const ex & lh,const ex & rh)76 const ex operator*(const ex & lh, const ex & rh)
77 {
78 return exmul(lh, rh);
79 }
80
operator /(const ex & lh,const ex & rh)81 const ex operator/(const ex & lh, const ex & rh)
82 {
83 return exmul(lh, power(rh,_ex_1));
84 }
85
86
87 // binary arithmetic operators numeric with numeric
88
operator +(const numeric & lh,const numeric & rh)89 const numeric operator+(const numeric & lh, const numeric & rh)
90 {
91 return lh.add(rh);
92 }
93
operator -(const numeric & lh,const numeric & rh)94 const numeric operator-(const numeric & lh, const numeric & rh)
95 {
96 return lh.sub(rh);
97 }
98
operator *(const numeric & lh,const numeric & rh)99 const numeric operator*(const numeric & lh, const numeric & rh)
100 {
101 return lh.mul(rh);
102 }
103
operator /(const numeric & lh,const numeric & rh)104 const numeric operator/(const numeric & lh, const numeric & rh)
105 {
106 return lh.div(rh);
107 }
108
109
110 // binary arithmetic assignment operators with ex
111
operator +=(ex & lh,const ex & rh)112 ex & operator+=(ex & lh, const ex & rh)
113 {
114 return lh = exadd(lh, rh);
115 }
116
operator -=(ex & lh,const ex & rh)117 ex & operator-=(ex & lh, const ex & rh)
118 {
119 return lh = exadd(lh, exminus(rh));
120 }
121
operator *=(ex & lh,const ex & rh)122 ex & operator*=(ex & lh, const ex & rh)
123 {
124 return lh = exmul(lh, rh);
125 }
126
operator /=(ex & lh,const ex & rh)127 ex & operator/=(ex & lh, const ex & rh)
128 {
129 return lh = exmul(lh, power(rh,_ex_1));
130 }
131
132
133 // unary operators
134
operator +(const ex & lh)135 const ex operator+(const ex & lh)
136 {
137 return lh;
138 }
139
operator -(const ex & lh)140 const ex operator-(const ex & lh)
141 {
142 return exminus(lh);
143 }
144
operator +(const numeric & lh)145 const numeric operator+(const numeric & lh)
146 {
147 return lh;
148 }
149
operator -(const numeric & lh)150 const numeric operator-(const numeric & lh)
151 {
152 return lh.negative(); // better than _num_1_p->mul(lh)
153 }
154
155
156 // increment / decrement operators
157
158 /** Expression prefix increment. Adds 1 and returns incremented ex. */
operator ++(ex & rh)159 ex & operator++(ex & rh)
160 {
161 if (is_exactly_a<numeric>(rh)) {
162 rh = numeric(ex_to<numeric>(rh) + *_num1_p);
163 return rh;
164 }
165 return rh = exadd(rh, _ex1);
166 }
167
168 /** Expression prefix decrement. Subtracts 1 and returns decremented ex. */
operator --(ex & rh)169 ex & operator--(ex & rh)
170 {
171 if (is_exactly_a<numeric>(rh)) {
172 rh = numeric(ex_to<numeric>(rh) + *_num_1_p);
173 return rh;
174 }
175 return rh = exadd(rh, _ex_1);
176 }
177
178 /** Expression postfix increment. Returns the ex and leaves the original
179 * incremented by 1. */
operator ++(ex & lh,int)180 const ex operator++(ex & lh, int)
181 {
182 ex tmp(lh);
183 lh = exadd(lh, _ex1);
184 return tmp;
185 }
186
187 /** Expression postfix decrement. Returns the ex and leaves the original
188 * decremented by 1. */
operator --(ex & lh,int)189 const ex operator--(ex & lh, int)
190 {
191 ex tmp(lh);
192 lh = exadd(lh, _ex_1);
193 return tmp;
194 }
195
196 /** Numeric prefix increment. Adds 1 and returns incremented number. */
operator ++(numeric & rh)197 numeric& operator++(numeric & rh)
198 {
199 rh = rh.add(*_num1_p);
200 return rh;
201 }
202
203 /** Numeric prefix decrement. Subtracts 1 and returns decremented number. */
operator --(numeric & rh)204 numeric& operator--(numeric & rh)
205 {
206 rh = rh.add(*_num_1_p);
207 return rh;
208 }
209
210 /** Numeric postfix increment. Returns the number and leaves the original
211 * incremented by 1. */
operator ++(numeric & lh,int)212 const numeric operator++(numeric & lh, int)
213 {
214 numeric tmp(lh);
215 lh = lh.add(*_num1_p);
216 return tmp;
217 }
218
219 /** Numeric postfix decrement. Returns the number and leaves the original
220 * decremented by 1. */
operator --(numeric & lh,int)221 const numeric operator--(numeric & lh, int)
222 {
223 numeric tmp(lh);
224 lh = lh.add(*_num_1_p);
225 return tmp;
226 }
227
228 // binary relational operators ex with ex
229
operator ==(const ex & lh,const ex & rh)230 const relational operator==(const ex & lh, const ex & rh)
231 {
232 return relational(lh,rh,relational::equal);
233 }
234
operator !=(const ex & lh,const ex & rh)235 const relational operator!=(const ex & lh, const ex & rh)
236 {
237 return relational(lh,rh,relational::not_equal);
238 }
239
operator <(const ex & lh,const ex & rh)240 const relational operator<(const ex & lh, const ex & rh)
241 {
242 return relational(lh,rh,relational::less);
243 }
244
operator <=(const ex & lh,const ex & rh)245 const relational operator<=(const ex & lh, const ex & rh)
246 {
247 return relational(lh,rh,relational::less_or_equal);
248 }
249
operator >(const ex & lh,const ex & rh)250 const relational operator>(const ex & lh, const ex & rh)
251 {
252 return relational(lh,rh,relational::greater);
253 }
254
operator >=(const ex & lh,const ex & rh)255 const relational operator>=(const ex & lh, const ex & rh)
256 {
257 return relational(lh,rh,relational::greater_or_equal);
258 }
259
// input/output stream operators and manipulators

// Index of this file's slot in every stream's iword()/pword() arrays.
// Obtained once from std::ios_base::xalloc() on first use and then reused.
static int my_ios_index()
{
	static const int slot_index = std::ios_base::xalloc();
	return slot_index;
}
267
268 // Stream format gets copied or destroyed
my_ios_callback(std::ios_base::event ev,std::ios_base & s,int i)269 static void my_ios_callback(std::ios_base::event ev, std::ios_base & s, int i)
270 {
271 std::unique_ptr<print_context> p(static_cast<print_context *>(s.pword(i)));
272 if (ev == std::ios_base::erase_event) {
273 s.pword(i) = nullptr;
274 } else if (ev == std::ios_base::copyfmt_event && p != nullptr)
275 s.pword(i) = p->duplicate();
276 }
277
// Bit flag kept in iword(my_ios_index()) of each stream: set once
// my_ios_callback has been registered on that stream.
static constexpr long callback_registered = 1;
281
282 // Get print_context associated with stream, may return 0 if no context has
283 // been associated yet
get_print_context(std::ios_base & s)284 static inline print_context *get_print_context(std::ios_base & s)
285 {
286 return static_cast<print_context *>(s.pword(my_ios_index()));
287 }
288
289 // Set print_context associated with stream, retain options
set_print_context(std::ios_base & s,const print_context & c)290 static void set_print_context(std::ios_base & s, const print_context & c)
291 {
292 int i = my_ios_index();
293 long flags = s.iword(i);
294 if ((flags & callback_registered) == 0) {
295 s.register_callback(my_ios_callback, i);
296 s.iword(i) = flags | callback_registered;
297 }
298 print_context *p = static_cast<print_context *>(s.pword(i));
299 unsigned options = p != nullptr ? p->options : c.options;
300 delete p;
301 p = c.duplicate();
302 p->options = options;
303 s.pword(i) = p;
304 }
305
306 // Get options for print_context associated with stream
get_print_options(std::ios_base & s)307 static inline unsigned get_print_options(std::ios_base & s)
308 {
309 print_context *p = get_print_context(s);
310 return p != nullptr ? p->options : 0;
311 }
312
313 // Set options for print_context associated with stream
set_print_options(std::ostream & s,unsigned options)314 static void set_print_options(std::ostream & s, unsigned options)
315 {
316 print_context *p = get_print_context(s);
317 if (p == nullptr)
318 set_print_context(s, print_dflt(s, options));
319 else
320 p->options = options;
321 }
322
operator <<(std::ostream & os,const ex & e)323 std::ostream & operator<<(std::ostream & os, const ex & e)
324 {
325 print_context *p = get_print_context(os);
326 if (p == nullptr)
327 e.print(print_dflt(os));
328 else
329 e.print(*p);
330 return os;
331 }
332
operator <<(std::ostream & os,const exvector & e)333 std::ostream & operator<<(std::ostream & os, const exvector & e)
334 {
335 print_context *p = get_print_context(os);
336 auto i = e.begin();
337 auto vend = e.end();
338
339 if (i==vend) {
340 os << "[]";
341 return os;
342 }
343
344 os << "[";
345 while (true) {
346 if (p == nullptr)
347 i -> print(print_dflt(os));
348 else
349 i -> print(*p);
350 ++i;
351 if (i==vend)
352 break;
353 os << ",";
354 }
355 os << "]";
356
357 return os;
358 }
359
operator <<(std::ostream & os,const exset & e)360 std::ostream & operator<<(std::ostream & os, const exset & e)
361 {
362 print_context *p = get_print_context(os);
363 auto i = e.begin();
364 auto send = e.end();
365
366 if (i==send) {
367 os << "<>";
368 return os;
369 }
370
371 os << "<";
372 while (true) {
373 if (p == nullptr)
374 i->print(print_dflt(os));
375 else
376 i->print(*p);
377 ++i;
378 if (i == send)
379 break;
380 os << ",";
381 }
382 os << ">";
383
384 return os;
385 }
386
operator <<(std::ostream & os,const exmap & e)387 std::ostream & operator<<(std::ostream & os, const exmap & e)
388 {
389 print_context *p = get_print_context(os);
390 auto i = e.begin();
391 auto mend = e.end();
392
393 if (i==mend) {
394 os << "{}";
395 return os;
396 }
397
398 os << "{";
399 while (true) {
400 if (p == nullptr)
401 i->first.print(print_dflt(os));
402 else
403 i->first.print(*p);
404 os << "==";
405 if (p == nullptr)
406 i->second.print(print_dflt(os));
407 else
408 i->second.print(*p);
409 ++i;
410 if( i==mend )
411 break;
412 os << ",";
413 }
414 os << "}";
415
416 return os;
417 }
418
/** Reading expressions from an input stream is not supported. This operator
 *  always throws std::logic_error and never touches is or e. */
std::istream & operator>>(std::istream & is, ex & e)
{
	throw std::logic_error("expression input from streams not implemented");
}
423
dflt(std::ostream & os)424 std::ostream & dflt(std::ostream & os)
425 {
426 set_print_context(os, print_dflt(os));
427 set_print_options(os, 0);
428 return os;
429 }
430
latex(std::ostream & os)431 std::ostream & latex(std::ostream & os)
432 {
433 set_print_context(os, print_latex(os));
434 return os;
435 }
436
python(std::ostream & os)437 std::ostream & python(std::ostream & os)
438 {
439 set_print_context(os, print_python(os));
440 return os;
441 }
442
python_repr(std::ostream & os)443 std::ostream & python_repr(std::ostream & os)
444 {
445 set_print_context(os, print_python_repr(os));
446 return os;
447 }
448
tree(std::ostream & os)449 std::ostream & tree(std::ostream & os)
450 {
451 set_print_context(os, print_tree(os));
452 return os;
453 }
454
index_dimensions(std::ostream & os)455 std::ostream & index_dimensions(std::ostream & os)
456 {
457 set_print_options(os, get_print_options(os) | print_options::print_index_dimensions);
458 return os;
459 }
460
no_index_dimensions(std::ostream & os)461 std::ostream & no_index_dimensions(std::ostream & os)
462 {
463 set_print_options(os, get_print_options(os) & ~print_options::print_index_dimensions);
464 return os;
465 }
466
467 } // namespace GiNaC
468