//===-- lib/Evaluate/complex.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/complex.h"
10 #include "llvm/Support/raw_ostream.h"
11 
12 namespace Fortran::evaluate::value {
13 
14 template <typename R>
Add(const Complex & that,Rounding rounding) const15 ValueWithRealFlags<Complex<R>> Complex<R>::Add(
16     const Complex &that, Rounding rounding) const {
17   RealFlags flags;
18   Part reSum{re_.Add(that.re_, rounding).AccumulateFlags(flags)};
19   Part imSum{im_.Add(that.im_, rounding).AccumulateFlags(flags)};
20   return {Complex{reSum, imSum}, flags};
21 }
22 
23 template <typename R>
Subtract(const Complex & that,Rounding rounding) const24 ValueWithRealFlags<Complex<R>> Complex<R>::Subtract(
25     const Complex &that, Rounding rounding) const {
26   RealFlags flags;
27   Part reDiff{re_.Subtract(that.re_, rounding).AccumulateFlags(flags)};
28   Part imDiff{im_.Subtract(that.im_, rounding).AccumulateFlags(flags)};
29   return {Complex{reDiff, imDiff}, flags};
30 }
31 
32 template <typename R>
Multiply(const Complex & that,Rounding rounding) const33 ValueWithRealFlags<Complex<R>> Complex<R>::Multiply(
34     const Complex &that, Rounding rounding) const {
35   // (a + ib)*(c + id) -> ac - bd + i(ad + bc)
36   RealFlags flags;
37   Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
38   Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
39   Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
40   Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
41   Part acbd{ac.Subtract(bd, rounding).AccumulateFlags(flags)};
42   Part adbc{ad.Add(bc, rounding).AccumulateFlags(flags)};
43   return {Complex{acbd, adbc}, flags};
44 }
45 
46 template <typename R>
Divide(const Complex & that,Rounding rounding) const47 ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
48     const Complex &that, Rounding rounding) const {
49   // (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)]
50   //   -> [ac+bd+i(bc-ad)] / (cc+dd)
51   //   -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd))
52   // but to avoid overflows, scale by d/c if c>=d, else c/d
53   Part scale; // <= 1.0
54   RealFlags flags;
55   bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less};
56   if (cGEd) {
57     scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags);
58   } else {
59     scale = that.re_.Divide(that.im_, rounding).AccumulateFlags(flags);
60   }
61   Part den;
62   if (cGEd) {
63     Part dS{scale.Multiply(that.im_, rounding).AccumulateFlags(flags)};
64     den = dS.Add(that.re_, rounding).AccumulateFlags(flags);
65   } else {
66     Part cS{scale.Multiply(that.re_, rounding).AccumulateFlags(flags)};
67     den = cS.Add(that.im_, rounding).AccumulateFlags(flags);
68   }
69   Part aS{scale.Multiply(re_, rounding).AccumulateFlags(flags)};
70   Part bS{scale.Multiply(im_, rounding).AccumulateFlags(flags)};
71   Part re1, im1;
72   if (cGEd) {
73     re1 = re_.Add(bS, rounding).AccumulateFlags(flags);
74     im1 = im_.Subtract(aS, rounding).AccumulateFlags(flags);
75   } else {
76     re1 = aS.Add(im_, rounding).AccumulateFlags(flags);
77     im1 = bS.Subtract(re_, rounding).AccumulateFlags(flags);
78   }
79   Part re{re1.Divide(den, rounding).AccumulateFlags(flags)};
80   Part im{im1.Divide(den, rounding).AccumulateFlags(flags)};
81   return {Complex{re, im}, flags};
82 }
83 
DumpHexadecimal() const84 template <typename R> std::string Complex<R>::DumpHexadecimal() const {
85   std::string result{'('};
86   result += re_.DumpHexadecimal();
87   result += ',';
88   result += im_.DumpHexadecimal();
89   result += ')';
90   return result;
91 }
92 
93 template <typename R>
AsFortran(llvm::raw_ostream & o,int kind) const94 llvm::raw_ostream &Complex<R>::AsFortran(llvm::raw_ostream &o, int kind) const {
95   re_.AsFortran(o << '(', kind);
96   im_.AsFortran(o << ',', kind);
97   return o << ')';
98 }
99 
100 template class Complex<Real<Integer<16>, 11>>;
101 template class Complex<Real<Integer<16>, 8>>;
102 template class Complex<Real<Integer<32>, 24>>;
103 template class Complex<Real<Integer<64>, 53>>;
104 template class Complex<Real<Integer<80>, 64>>;
105 template class Complex<Real<Integer<128>, 113>>;
106 } // namespace Fortran::evaluate::value
107