1 // Copyright (c) 2017 Commissariat à l'énergie atomique et aux énergies alternatives (CEA)
2 // Copyright (c) 2017 Centre national de la recherche scientifique (CNRS)
3 // Copyright (c) 2020 Simons Foundation
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0.txt
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 // Authors: Olivier Parcollet, Nils Wentzell
18 
19 #pragma once
20 #include "../pyref.hpp"
21 
22 #include <numpy/arrayobject.h>
23 
24 namespace cpp2py {
25 
26   // --- complex
27 
28   template <> struct py_converter<std::complex<double>> {
c2pycpp2py::py_converter29     static PyObject *c2py(std::complex<double> x) { return PyComplex_FromDoubles(x.real(), x.imag()); }
py2ccpp2py::py_converter30     static std::complex<double> py2c(PyObject *ob) {
31 
32       if (PyArray_CheckScalar(ob)) {
33         // Convert NPY Scalar Type to Builtin Type
34         pyref py_builtin = PyObject_CallMethod(ob, "item", NULL);
35         if (PyComplex_Check(py_builtin)) {
36           auto r = PyComplex_AsCComplex(py_builtin);
37           return {r.real, r.imag};
38         } else {
39           return PyFloat_AsDouble(py_builtin);
40         }
41       }
42 
43       if (PyComplex_Check(ob)) {
44         auto r = PyComplex_AsCComplex(ob);
45         return {r.real, r.imag};
46       }
47       return PyFloat_AsDouble(ob);
48     }
is_convertiblecpp2py::py_converter49     static bool is_convertible(PyObject *ob, bool raise_exception) {
50       if (PyComplex_Check(ob) || PyFloat_Check(ob) || PyLong_Check(ob)) return true;
51       if (PyArray_CheckScalar(ob)) {
52         pyref py_arr = PyArray_FromScalar(ob, NULL);
53         if (PyArray_ISINTEGER((PyObject*)py_arr) or PyArray_ISFLOAT((PyObject*)py_arr) or PyArray_ISCOMPLEX((PyObject*)py_arr)) return true;
54       }
55       if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " to complex"s).c_str()); }
56       return false;
57     }
58   };
59 
60 } // namespace cpp2py
61