1 #ifndef CXX_MPZ_HPP_
2 #define CXX_MPZ_HPP_
3
4 #include "macros.h"
5
6 #include <gmp.h>
7 #include <istream>
8 #include <ostream>
9 #include <limits>
10 #include <type_traits>
11 #include <stdlib.h>
12 #include "gmp_aux.h"
13 #include "gmp_auxx.hpp"
14
15 struct cxx_mpz {
16 public:
17 typedef mp_limb_t WordType;
18 mpz_t x;
cxx_mpzcxx_mpz19 cxx_mpz() { mpz_init(x); }
20 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
cxx_mpzcxx_mpz21 cxx_mpz (const T & rhs) {
22 gmp_auxx::mpz_init_set(x, rhs);
23 }
24
~cxx_mpzcxx_mpz25 ~cxx_mpz() { mpz_clear(x); }
cxx_mpzcxx_mpz26 cxx_mpz(cxx_mpz const & o) {
27 mpz_init_set(x, o.x);
28 }
operator =cxx_mpz29 cxx_mpz & operator=(cxx_mpz const & o) {
30 mpz_set(x, o.x);
31 return *this;
32 }
33 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator =cxx_mpz34 cxx_mpz & operator=(const T a) {
35 gmp_auxx::mpz_set(x, a);
36 return *this;
37 }
38
39 #if __cplusplus >= 201103L
cxx_mpzcxx_mpz40 cxx_mpz(cxx_mpz && o) {
41 mpz_init(x);
42 mpz_swap(x, o.x);
43 }
operator =cxx_mpz44 cxx_mpz& operator=(cxx_mpz && o) {
45 mpz_swap(x, o.x);
46 return *this;
47 }
48 #endif
operator mpz_ptrcxx_mpz49 operator mpz_ptr() { return x; }
operator mpz_srcptrcxx_mpz50 operator mpz_srcptr() const { return x; }
operator ->cxx_mpz51 mpz_ptr operator->() { return x; }
operator ->cxx_mpz52 mpz_srcptr operator->() const { return x; }
operator uint64_tcxx_mpz53 explicit operator uint64_t() const {return mpz_get_uint64(x);}
54
55 /** Set the value of the cxx_mpz to that of the uint64_t array s,
56 * least significant word first.
57 */
setcxx_mpz58 bool set(const uint64_t *s, const size_t len) {
59 mpz_import(x, len, -1, sizeof(uint64_t), 0, 0, s);
60 return true;
61 }
62
63 /** Return the size in uint64_ts that is required in the output for
64 * get(uint64_t *, size_t) */
sizecxx_mpz65 size_t size() const {return iceildiv(mpz_sizeinbase(x, 2), 64);}
66
67 /** Write the absolute value of the cxx_mpz to r.
68 * The least significant word is written first. Exactly len words are
69 * written. If len is less than the required size as given by size(),
70 * output is truncated. If len is greater, output is padded with zeroes.
71 */
getcxx_mpz72 void get(uint64_t *r, const size_t len) const {
73 const bool useTemp = len < size();
74 size_t written;
75 uint64_t *t = (uint64_t *) mpz_export(useTemp ? NULL : r, &written,
76 -1, sizeof(uint64_t), 0, 0, x);
77 if (useTemp) {
78 /* Here, len < written. Write only the len least significant words
79 * to r */
80 for (size_t i = 0; i < len; i++)
81 r[i] = t[i];
82 free(t);
83 } else {
84 ASSERT_ALWAYS(written <= len);
85 for (size_t i = written; i < len; i++)
86 r[i] = 0;
87 }
88 }
89
90 /* Should use a C++ iterator instead? Would that be slower? */
getWordSizecxx_mpz91 static int getWordSize() {return GMP_NUMB_BITS;}
getWordCountcxx_mpz92 size_t getWordCount() const {return mpz_size(x);}
getWordcxx_mpz93 WordType getWord(const size_t i) const {return mpz_getlimbn(x, i);}
94
95 template <typename T>
fitscxx_mpz96 bool fits() const {
97 return gmp_auxx::mpz_fits<T>(x);
98 }
99 };
100
101 template <>
fits() const102 inline bool cxx_mpz::fits<cxx_mpz>() const {
103 return true;
104 }
105
106 #if GNUC_VERSION_ATLEAST(4,3,0)
107 extern void mpz_init(cxx_mpz & pl) __attribute__((error("mpz_init must not be called on a mpz reference -- it is the caller's business (via a ctor)")));
108 extern void mpz_clear(cxx_mpz & pl) __attribute__((error("mpz_clear must not be called on a mpz reference -- it is the caller's business (via a dtor)")));
109 #endif
110
111 struct cxx_mpq{
112 mpq_t x;
cxx_mpqcxx_mpq113 cxx_mpq() {mpq_init(x);}
~cxx_mpqcxx_mpq114 ~cxx_mpq() {mpq_clear(x);}
cxx_mpqcxx_mpq115 cxx_mpq(unsigned long a, unsigned long b = 1) { mpq_init(x); mpq_set_ui(x, a,b); }
cxx_mpqcxx_mpq116 cxx_mpq(cxx_mpq const & o) {
117 mpq_init(x);
118 mpq_set(x, o.x);
119 }
operator =cxx_mpq120 cxx_mpq & operator=(cxx_mpq const & o) {
121 mpq_set(x, o.x);
122 return *this;
123 }
124 #if __cplusplus >= 201103L
cxx_mpqcxx_mpq125 cxx_mpq(cxx_mpq && o) {
126 mpq_init(x);
127 mpq_swap(x, o.x);
128 }
operator =cxx_mpq129 cxx_mpq& operator=(cxx_mpq && o) {
130 mpq_swap(x, o.x);
131 return *this;
132 }
133 #endif
operator mpq_ptrcxx_mpq134 operator mpq_ptr() { return x; }
operator mpq_srcptrcxx_mpq135 operator mpq_srcptr() const { return x; }
operator ->cxx_mpq136 mpq_ptr operator->() { return x; }
operator ->cxx_mpq137 mpq_srcptr operator->() const { return x; }
138 };
139 #if GNUC_VERSION_ATLEAST(4,3,0)
140 extern void mpq_init(cxx_mpq & pl) __attribute__((error("mpq_init must not be called on a mpq reference -- it is the caller's business (via a ctor)")));
141 extern void mpq_clear(cxx_mpq & pl) __attribute__((error("mpq_clear must not be called on a mpq reference -- it is the caller's business (via a dtor)")));
142 #endif
143
144 #define CXX_MPZ_DEFINE_CMP(OP) \
145 inline bool operator OP(cxx_mpz const & a, cxx_mpz const & b) { return mpz_cmp(a, b) OP 0; } \
146 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 > \
147 inline bool operator OP(cxx_mpz const & a, const T b) { return gmp_auxx::mpz_cmp(a, b) OP 0; } \
148 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 > \
149 inline bool operator OP(const T a, cxx_mpz const & b) { return 0 OP gmp_auxx::mpz_cmp(b, a); }
150
151 CXX_MPZ_DEFINE_CMP(==)
152 CXX_MPZ_DEFINE_CMP(!=)
153 CXX_MPZ_DEFINE_CMP(<)
154 CXX_MPZ_DEFINE_CMP(>)
155 CXX_MPZ_DEFINE_CMP(<=)
156 CXX_MPZ_DEFINE_CMP(>=)
157
operator +(cxx_mpz const & a,cxx_mpz const & b)158 inline cxx_mpz operator+(cxx_mpz const & a, cxx_mpz const & b) { cxx_mpz r; mpz_add(r, a, b); return r; }
159 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator +(cxx_mpz const & a,const T b)160 inline cxx_mpz operator+(cxx_mpz const & a, const T b) { cxx_mpz r; gmp_auxx::mpz_add(r, a, b); return r; }
161 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator +(const T a,cxx_mpz const & b)162 inline cxx_mpz operator+(const T a, cxx_mpz const & b) { cxx_mpz r; gmp_auxx::mpz_add(r, b, a); return r; }
163
operator +=(cxx_mpz & a,cxx_mpz const & b)164 inline cxx_mpz & operator+=(cxx_mpz & a, cxx_mpz const & b) { mpz_add(a, a, b); return a; }
165 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator +=(cxx_mpz & a,const T b)166 inline cxx_mpz & operator+=(cxx_mpz & a, const T b) { gmp_auxx::mpz_add(a, a, b); return a; }
167
operator -(cxx_mpz const & a,cxx_mpz const & b)168 inline cxx_mpz operator-(cxx_mpz const & a, cxx_mpz const & b) { cxx_mpz r; mpz_sub(r, a, b); return r; }
169 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator -(cxx_mpz const & a,const T b)170 inline cxx_mpz operator-(cxx_mpz const & a, const T b) { cxx_mpz r; gmp_auxx::mpz_sub(r, a, b); return r; }
171 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator -(const T a,cxx_mpz const & b)172 inline cxx_mpz operator-(const T a, cxx_mpz const & b) { cxx_mpz r; gmp_auxx::mpz_sub(r, a, b); return r; }
173
operator -=(cxx_mpz & a,cxx_mpz const & b)174 inline cxx_mpz & operator-=(cxx_mpz & a, cxx_mpz const & b) { mpz_sub(a, a, b); return a; }
175 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator -=(cxx_mpz & a,const T b)176 inline cxx_mpz & operator-=(cxx_mpz & a, const T b) { gmp_auxx::mpz_sub(a, a, b); return a; }
177
operator *(cxx_mpz const & a,cxx_mpz const & b)178 inline cxx_mpz operator*(cxx_mpz const & a, cxx_mpz const & b) { cxx_mpz r; mpz_mul(r, a, b); return r; }
179 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator *(cxx_mpz const & a,const T b)180 inline cxx_mpz operator*(cxx_mpz const & a, const T b) { cxx_mpz r; gmp_auxx::mpz_mul(r, a, b); return r; }
181 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator *(const T a,cxx_mpz const & b)182 inline cxx_mpz operator*(const T a, cxx_mpz const & b) { cxx_mpz r; gmp_auxx::mpz_mul(r, b, a); return r; }
183
operator *=(cxx_mpz & a,cxx_mpz const & b)184 inline cxx_mpz & operator*=(cxx_mpz & a, cxx_mpz const & b) { mpz_mul(a, a, b); return a; }
185 template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0 >
operator *=(cxx_mpz & a,const T b)186 inline cxx_mpz & operator*=(cxx_mpz & a, const T b) { gmp_auxx::mpz_mul(a, a, b); return a; }
187
operator /(cxx_mpz const & a,cxx_mpz const & b)188 inline cxx_mpz operator/(cxx_mpz const & a, cxx_mpz const & b) { cxx_mpz r; mpz_tdiv_q(r, a, b); return r; }
189 template <typename T, std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, int> = 0 >
operator /(cxx_mpz const & a,const T b)190 inline cxx_mpz operator/(cxx_mpz const & a, const T b) { cxx_mpz r; mpz_tdiv_q_uint64(r, a, b); return r; }
191
operator /=(cxx_mpz & a,cxx_mpz const & b)192 inline cxx_mpz & operator/=(cxx_mpz & a, cxx_mpz const & b) { mpz_tdiv_q(a, a, b); return a; }
193 template <typename T, std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, int> = 0 >
operator /=(cxx_mpz & a,const T b)194 inline cxx_mpz & operator/=(cxx_mpz & a, const T b) { mpz_tdiv_q_uint64(a, a, b); return a; }
195
operator %(cxx_mpz const & a,cxx_mpz const & b)196 inline cxx_mpz operator%(cxx_mpz const & a, cxx_mpz const & b) { cxx_mpz r; mpz_tdiv_r(r, a, b); return r; }
197 template <typename T, std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, int> = 0 >
operator %(cxx_mpz const & a,const T b)198 inline cxx_mpz operator%(cxx_mpz const & a, const T b) { cxx_mpz r; mpz_tdiv_r_uint64(r, a, b); return r; }
199
operator %=(cxx_mpz & a,cxx_mpz const & b)200 inline cxx_mpz & operator%=(cxx_mpz & a, cxx_mpz const & b) { mpz_tdiv_r(a, a, b); return a; }
201 template <typename T, std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, int> = 0 >
operator %=(cxx_mpz & a,const T b)202 inline cxx_mpz & operator%=(cxx_mpz & a, const T b) { mpz_tdiv_r_uint64(a, a, b); return a; }
203
operator <<=(cxx_mpz & a,const mp_bitcnt_t s)204 inline cxx_mpz & operator<<=(cxx_mpz & a, const mp_bitcnt_t s) { mpz_mul_2exp(a, a, s); return a; }
operator <<(cxx_mpz & a,const mp_bitcnt_t s)205 inline cxx_mpz operator<<(cxx_mpz & a, const mp_bitcnt_t s) { cxx_mpz r{a}; mpz_mul_2exp(r, r, s); return r; }
206
operator >>=(cxx_mpz & a,const mp_bitcnt_t s)207 inline cxx_mpz & operator>>=(cxx_mpz & a, const mp_bitcnt_t s) { mpz_tdiv_q_2exp(a, a, s); return a; }
operator >>(cxx_mpz & a,const mp_bitcnt_t s)208 inline cxx_mpz operator>>(cxx_mpz & a, const mp_bitcnt_t s) { cxx_mpz r{a}; mpz_tdiv_q_2exp(r, r, s); return r; }
209
210
#if 0
/* Disabled bitwise-or operators, kept for reference. The unsigned long
 * overloads wrap the operand in a temporary cxx_mpz before calling
 * mpz_ior. */
inline cxx_mpz operator|(cxx_mpz const & lhs, cxx_mpz const & rhs)
{
    cxx_mpz res;
    mpz_ior(res, lhs, rhs);
    return res;
}
inline cxx_mpz operator|(cxx_mpz const & lhs, const unsigned long rhs)
{
    cxx_mpz res;
    mpz_ior(res, lhs, cxx_mpz(rhs));
    return res;
}
inline cxx_mpz & operator|=(cxx_mpz & acc, cxx_mpz const & rhs)
{
    mpz_ior(acc, acc, rhs);
    return acc;
}
inline cxx_mpz & operator|=(cxx_mpz & acc, const unsigned long rhs)
{
    mpz_ior(acc, acc, cxx_mpz(rhs));
    return acc;
}
#endif
217
operator ==(cxx_mpq const & a,cxx_mpq const & b)218 inline bool operator==(cxx_mpq const & a, cxx_mpq const & b) { return mpq_cmp(a, b) == 0; }
operator !=(cxx_mpq const & a,cxx_mpq const & b)219 inline bool operator!=(cxx_mpq const & a, cxx_mpq const & b) { return mpq_cmp(a, b) != 0; }
operator <(cxx_mpq const & a,cxx_mpq const & b)220 inline bool operator<(cxx_mpq const & a, cxx_mpq const & b) { return mpq_cmp(a, b) < 0; }
operator >(cxx_mpq const & a,cxx_mpq const & b)221 inline bool operator>(cxx_mpq const & a, cxx_mpq const & b) { return mpq_cmp(a, b) > 0; }
operator <<(std::ostream & os,cxx_mpz const & x)222 inline std::ostream& operator<<(std::ostream& os, cxx_mpz const& x) { return os << (mpz_srcptr) x; }
operator <<(std::ostream & os,cxx_mpq const & x)223 inline std::ostream& operator<<(std::ostream& os, cxx_mpq const& x) { return os << (mpq_srcptr) x; }
operator >>(std::istream & is,cxx_mpz & x)224 inline std::istream& operator>>(std::istream& is, cxx_mpz & x) { return is >> (mpz_ptr) x; }
operator >>(std::istream & is,cxx_mpq & x)225 inline std::istream& operator>>(std::istream& is, cxx_mpq & x) { return is >> (mpq_ptr) x; }
226 #endif /* CXX_MPZ_HPP_ */
227