1 // Copyright 2020, 2021 PaGMO development team
2 //
3 // This file is part of the pygmo library.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla
6 // Public License v. 2.0. If a copy of the MPL was not distributed
7 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
8 
9 #ifndef PYGMO_S11N_WRAPPERS_HPP
10 #define PYGMO_S11N_WRAPPERS_HPP
11 
12 #include <cstddef>
13 #include <sstream>
14 #include <string>
15 #include <type_traits>
16 #include <vector>
17 
18 #include <boost/archive/binary_iarchive.hpp>
19 #include <boost/archive/binary_oarchive.hpp>
20 #include <boost/numeric/conversion/cast.hpp>
21 #include <boost/serialization/base_object.hpp>
22 #include <boost/serialization/binary_object.hpp>
23 
24 #include <pybind11/pybind11.h>
25 
26 #include "common_utils.hpp"
27 
28 namespace pygmo
29 {
30 
31 namespace py = pybind11;
32 
33 // Two helpers to implement s11n for the *_inner classes
34 // specialisations for py::object. d is an instance of the *_inner
35 // class, Base its base type.
36 template <typename Base, typename Archive, typename Derived>
inner_class_save(Archive & ar,const Derived & d)37 inline void inner_class_save(Archive &ar, const Derived &d)
38 {
39     static_assert(std::is_base_of_v<Base, Derived>);
40 
41     // Serialize the base class.
42     ar << boost::serialization::base_object<Base>(d);
43 
44     // This will dump m_value into a bytes object..
45     auto tmp = py::module::import("pygmo").attr("get_serialization_backend")().attr("dumps")(d.m_value);
46 
47     // This gives a null-terminated char * to the internal
48     // content of the bytes object.
49     auto ptr = PyBytes_AsString(tmp.ptr());
50     if (!ptr) {
51         py_throw(PyExc_TypeError, "The serialization backend's dumps() function did not return a bytes object");
52     }
53 
54     // NOTE: this will be the length of the bytes object *without* the terminator.
55     const auto size = boost::numeric_cast<std::size_t>(py::len(tmp));
56 
57     // Save the binary size.
58     ar << size;
59 
60     // Save the binary object.
61     ar << boost::serialization::make_binary_object(ptr, size);
62 }
63 
64 template <typename Base, typename Archive, typename Derived>
inner_class_load(Archive & ar,Derived & d)65 inline void inner_class_load(Archive &ar, Derived &d)
66 {
67     static_assert(std::is_base_of_v<Base, Derived>);
68 
69     // Deserialize the base class.
70     ar >> boost::serialization::base_object<Base>(d);
71 
72     // Recover the size.
73     std::size_t size{};
74     ar >> size;
75 
76     // Recover the binary object.
77     std::vector<char> tmp;
78     tmp.resize(boost::numeric_cast<decltype(tmp.size())>(size));
79     ar >> boost::serialization::make_binary_object(tmp.data(), size);
80 
81     // Deserialise and assign.
82     auto b = py::bytes(tmp.data(), boost::numeric_cast<py::size_t>(size));
83     d.m_value = py::module::import("pygmo").attr("get_serialization_backend")().attr("loads")(b);
84 }
85 
86 // Helpers to implement pickling on top of Boost.Serialization.
87 template <typename T>
pickle_getstate_wrapper(const T & x)88 inline py::tuple pickle_getstate_wrapper(const T &x)
89 {
90     std::ostringstream oss;
91     {
92         boost::archive::binary_oarchive oa(oss);
93         oa << x;
94     }
95 
96     return py::make_tuple(py::bytes(oss.str()));
97 }
98 
99 template <typename T>
pickle_setstate_wrapper(py::tuple state)100 inline T pickle_setstate_wrapper(py::tuple state)
101 {
102     if (py::len(state) != 1) {
103         py_throw(PyExc_ValueError, ("The state tuple passed to the deserialization wrapper "
104                                     "must have 1 element, but instead it has "
105                                     + std::to_string(py::len(state)) + " element(s)")
106                                        .c_str());
107     }
108 
109     auto ptr = PyBytes_AsString(state[0].ptr());
110     if (!ptr) {
111         py_throw(PyExc_TypeError, "A bytes object is needed in the deserialization wrapper");
112     }
113 
114     std::istringstream iss;
115     iss.str(std::string(ptr, ptr + py::len(state[0])));
116     T x;
117     {
118         boost::archive::binary_iarchive iarchive(iss);
119         iarchive >> x;
120     }
121 
122     return x;
123 }
124 
125 } // namespace pygmo
126 
127 #endif
128