1 // Copyright (c) 2017 Commissariat à l'énergie atomique et aux énergies alternatives (CEA)
2 // Copyright (c) 2017 Centre national de la recherche scientifique (CNRS)
3 // Copyright (c) 2019-2020 Simons Foundation
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0.txt
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 // Authors: Gregory Kramida, Olivier Parcollet, Nils Wentzell
18 
19 #pragma once
20 #include <functional>
21 #include "../pyref.hpp"
22 
23 namespace cpp2py {
24 
25   // ---- function ----
26 
27   // a few useful meta tricks
// Tag type carrying a compile-time integer; used to drive the compile-time
// recursion over argument positions in _parse.
template <int N> struct _int {};

// Compile-time list of integers; expanded as std::get<Is>(tuple)... to unpack
// a tuple of arguments into a function call.
template <int... N> struct index_seq {};

// Tag type carrying a type without creating an object of it (works for void);
// used to overload on the return type.
template <typename U> struct nop {};

// Recursive builder for index_seq: each step prepends N - 1 to the accumulated
// pack Is until N reaches 0.
template <int N, int... Is> struct _make_index_seq_impl : _make_index_seq_impl<N - 1, N - 1, Is...> {};
template <int... Is> struct _make_index_seq_impl<0, Is...> { using type = index_seq<Is...>; };

// _make_index_seq<N>::type is index_seq<0, 1, ..., N - 1> for any N >= 0
// (index_seq<> for N == 0).
template <int N> struct _make_index_seq {
  static_assert(N >= 0, "make_index_seq requires a non-negative length");
  // The clamp keeps a negative N from recursing forever, so the static_assert
  // above is the diagnostic the user sees.
  using type = typename _make_index_seq_impl<(N > 0 ? N : 0)>::type;
};

// Convenience alias: make_index_seq<N> is index_seq<0, ..., N - 1>.
template <int N> using make_index_seq = typename _make_index_seq<N>::type;
40 
41   // details
42   template <bool B> struct _bool {};
43   template <typename T> struct _is_pointer : _bool<false> {};
44   template <typename T> struct _is_pointer<T *> : _bool<true> {};
45   template <> struct _is_pointer<PyObject *> : _bool<false> {}; // yes, false, it is a special case...
46 
47   // adapter needed for parsing with PyArg_ParseTupleAndKeywords later in the functions
converter_for_parser_fnt_(PyObject * ob,T * p,_bool<false>)48   template <typename T> static int converter_for_parser_fnt_(PyObject *ob, T *p, _bool<false>) {
49     if (!py_converter<T>::is_convertible(ob, true)) return 0;
50     *p = std::move(convert_from_python<T>(ob)); // non wrapped types are converted to values, they can be moved !
51     return 1;
52   }
converter_for_parser_fnt_(PyObject * ob,T ** p,_bool<true>)53   template <typename T> static int converter_for_parser_fnt_(PyObject *ob, T **p, _bool<true>) {
54     if (!convertible_from_python<T>(ob)) return 0;
55     *p = &(convert_from_python<T>(ob));
56     return 1;
57   }
converter_for_parser_fnt(PyObject * ob,T * p)58   template <typename T> static int converter_for_parser_fnt(PyObject *ob, T *p) { return converter_for_parser_fnt_(ob, p, _is_pointer<T>()); }
59 
60   template <typename R, typename... T> struct py_converter<std::function<R(T...)>> {
61 
62     static_assert(sizeof...(T) < 5, "More than 5 variables not implemented");
63     typedef struct {
64       PyObject_HEAD std::function<R(T...)> *_c;
65     } std_function;
66 
std_function_newcpp2py::py_converter67     static PyObject *std_function_new(PyTypeObject *type, PyObject *args, PyObject *kwds) {
68       std_function *self;
69       self = (std_function *)type->tp_alloc(type, 0);
70       if (self != NULL) { self->_c = new std::function<R(T...)>{}; }
71       return (PyObject *)self;
72     }
73 
std_function_dealloccpp2py::py_converter74     static void std_function_dealloc(std_function *self) {
75       delete self->_c;
76       Py_TYPE(self)->tp_free((PyObject *)self);
77     }
78 
79     // technical details to implement the __call function of the wrapping python object, cf below
80     // we are using the unpack trick of the apply proposal for the C++ standard : cf XXXX
81     //
82     // specialise the convertion of the return type in the void case
83     template <typename RR, typename TU, int... Is>
_call_and_treat_returncpp2py::py_converter84     static PyObject *_call_and_treat_return(nop<RR>, std_function *pyf, TU const &tu, index_seq<Is...>) {
85       return py_converter<RR>::c2py(pyf->_c->operator()(std::get<Is>(tu)...));
86     }
_call_and_treat_returncpp2py::py_converter87     template <typename TU, int... Is> static PyObject *_call_and_treat_return(nop<void>, std_function *pyf, TU const &tu, index_seq<Is...>) {
88       pyf->_c->operator()(std::get<Is>(tu)...);
89       Py_RETURN_NONE;
90     }
91 
92     using arg_tuple_t = std::tuple<T...>;
93     using _int_max    = _int<sizeof...(T) - 1>;
94 
_parsecpp2py::py_converter95     template <typename... U> static int _parse(_int<-1>, PyObject *args, arg_tuple_t &tu, U... u) {
96       const char *format = "O&O&O&O&O&"; // change 5 for more arguments.
97       static_assert(sizeof...(T) <= 5, "More than 5 not implement. Easy to do ...");
98       return PyArg_ParseTuple(args, format + 2 * (5 - sizeof...(T)), u...);
99     }
_parsecpp2py::py_converter100     template <int N, typename... U> static int _parse(_int<N>, PyObject *args, arg_tuple_t &tu, U... u) {
101       return _parse(_int<N - 1>(), args, tu, converter_for_parser_fnt<typename std::tuple_element<N, typename std::decay<arg_tuple_t>::type>::type>,
102                     &std::get<N>(tu), u...);
103     }
104 
105     // the call function object ...
106     // TODO : ADD THE REF AND POINTERS  in x ??
std_function_callcpp2py::py_converter107     static PyObject *std_function_call(PyObject *self, PyObject *args, PyObject *kwds) {
108       arg_tuple_t x;
109       if (!_parse(_int_max(), args, x)) return NULL;
110       try {
111         return _call_and_treat_return(nop<R>(), (std_function *)self, x, make_index_seq<sizeof...(T)>());
112       }
113       CATCH_AND_RETURN("calling C++ std::function ", NULL);
114     }
115 
get_typecpp2py::py_converter116     static PyTypeObject get_type() {
117       return {
118          PyVarObject_HEAD_INIT(NULL, 0)       /*ob_size*/
119          "std_function",                      /*tp_name*/
120          sizeof(std_function),                /*tp_basicsize*/
121          0,                                   /*tp_itemsize*/
122 
123          (destructor)std_function_dealloc,    /*tp_dealloc*/
124          0,                                   /*tp_print*/
125          0,                                   /*tp_getattr*/
126          0,                                   /*tp_setattr*/
127          0,                                   /*tp_compare in py2, tp_as_async in py3*/
128 
129          0,                                   /*tp_repr*/
130 
131          0,                                   /*tp_as_number*/
132          0,                                   /*tp_as_sequence*/
133          0,                                   /*tp_as_mapping*/
134 
135          0,                                   /*tp_hash */
136          (ternaryfunc)std_function_call,      /*tp_call*/
137          0,                                   /*tp_str*/
138          0,                                   /*tp_getattro*/
139          0,                                   /*tp_setattro*/
140 
141          0,                                   /*tp_as_buffer*/
142 
143          Py_TPFLAGS_DEFAULT,                  /*tp_flags*/
144 
145          "Internal wrapper of std::function", /* tp_doc */
146 
147          0,                                   /* tp_traverse */
148 
149          0,                                   /* tp_clear */
150 
151          0,                                   /* tp_richcompare */
152 
153          0,                                   /* tp_weaklistoffset */
154 
155          0,                                   /* tp_iter */
156          0,                                   /* tp_iternext */
157 
158          0,                                   /* tp_methods */
159          0,                                   /* tp_members */
160          0,                                   /* tp_getset */
161          0,                                   /* tp_base */
162          0,                                   /* tp_dict */
163          0,                                   /* tp_descr_get */
164          0,                                   /* tp_descr_set */
165          0,                                   /* tp_dictoffset */
166          0,                                   /* tp_init */
167          0,                                   /* tp_alloc */
168          (newfunc)std_function_new,           /* tp_new */
169       };
170     }
171 
ensure_type_readycpp2py::py_converter172     static void ensure_type_ready(PyTypeObject &Type, bool &ready) {
173       if (!ready) {
174         Type = get_type();
175         if (PyType_Ready(&Type) < 0) std::cerr << " Warning : ensure_type_ready has failed in function-lambda C++/Python converter " << std::endl;
176         ready = true;
177       }
178     }
179 
180     // U can be anything, typically a lambda
c2pycpp2py::py_converter181     template <typename U> static PyObject *c2py(U &&x) {
182       std_function *self;
183       static PyTypeObject Type;
184       static bool ready = false;
185       ensure_type_ready(Type, ready);
186       self = (std_function *)Type.tp_alloc(&Type, 0);
187       if (self != NULL) { self->_c = new std::function<R(T...)>{std::forward<U>(x)}; }
188       return (PyObject *)self;
189     }
190 
is_convertiblecpp2py::py_converter191     static bool is_convertible(PyObject *ob, bool raise_exception) {
192       if (PyCallable_Check(ob)) return true;
193       if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " std::function as it is not callable"s).c_str()); }
194       return false;
195     }
196 
py2ccpp2py::py_converter197     static std::function<R(T...)> py2c(PyObject *ob) {
198       static PyTypeObject Type;
199       static bool ready = false;
200       ensure_type_ready(Type, ready);
201       // If we convert a wrapped std::function, just extract it.
202       if (PyObject_TypeCheck(ob, &Type)) { return *(((std_function *)ob)->_c); }
203       // otherwise, we build a new std::function around the python function
204       pyref py_fnt = borrowed(ob);
205       auto l       = [py_fnt](T... x) mutable -> R { // py_fnt is a pyref, it will keep the ref and manage the ref counting...
206         pyref ret  = PyObject_CallFunctionObjArgs(py_fnt, (PyObject *)pyref(convert_to_python(x))..., NULL);
207         if (not py_converter<R>::is_convertible(ret, false)) {
208           CPP2PY_RUNTIME_ERROR << "\n Cannot convert function result " << to_string(ret) << " from python to C++";
209         }
210         return py_converter<R>::py2c(ret);
211       };
212       return l;
213     }
214   };
215 
216 } // namespace cpp2py
217