1 // Copyright (c) 2017 Commissariat à l'énergie atomique et aux énergies alternatives (CEA) 2 // Copyright (c) 2017 Centre national de la recherche scientifique (CNRS) 3 // Copyright (c) 2020 Simons Foundation 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0.txt 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // 17 // Authors: Olivier Parcollet, Nils Wentzell 18 19 #pragma once 20 #include "../pyref.hpp" 21 22 #include <numpy/arrayobject.h> 23 24 namespace cpp2py { 25 26 // --- complex 27 28 template <> struct py_converter<std::complex<double>> { c2pycpp2py::py_converter29 static PyObject *c2py(std::complex<double> x) { return PyComplex_FromDoubles(x.real(), x.imag()); } py2ccpp2py::py_converter30 static std::complex<double> py2c(PyObject *ob) { 31 32 if (PyArray_CheckScalar(ob)) { 33 // Convert NPY Scalar Type to Builtin Type 34 pyref py_builtin = PyObject_CallMethod(ob, "item", NULL); 35 if (PyComplex_Check(py_builtin)) { 36 auto r = PyComplex_AsCComplex(py_builtin); 37 return {r.real, r.imag}; 38 } else { 39 return PyFloat_AsDouble(py_builtin); 40 } 41 } 42 43 if (PyComplex_Check(ob)) { 44 auto r = PyComplex_AsCComplex(ob); 45 return {r.real, r.imag}; 46 } 47 return PyFloat_AsDouble(ob); 48 } is_convertiblecpp2py::py_converter49 static bool is_convertible(PyObject *ob, bool raise_exception) { 50 if (PyComplex_Check(ob) || PyFloat_Check(ob) || PyLong_Check(ob)) return true; 51 if (PyArray_CheckScalar(ob)) { 52 pyref py_arr = PyArray_FromScalar(ob, NULL); 53 if (PyArray_ISINTEGER((PyObject*)py_arr) or PyArray_ISFLOAT((PyObject*)py_arr) or PyArray_ISCOMPLEX((PyObject*)py_arr)) return true; 54 } 55 if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " to complex"s).c_str()); } 56 return false; 57 } 58 }; 59 60 } // namespace cpp2py 61