1 #include <Minuit2/MnUserTransformation.h>
2 #include <pybind11/pybind11.h>
3 #include <pybind11/stl.h>
4 #include <type_traits>
5 
6 namespace py = pybind11;
7 using namespace ROOT::Minuit2;
8 
// The pickle support in this file reaches into private members of
// MnUserTransformation through `Layout` (via reinterpret_cast). Standard
// layout is a necessary, but not sufficient, condition for that to behave
// predictably, so the build fails if the class ever stops being
// standard-layout.
static_assert(std::is_standard_layout<MnUserTransformation>(), "");

// Field-for-field mirror of the private data members of
// ROOT::Minuit2::MnUserTransformation. The pickle getter/setter cast a
// MnUserTransformation to this type to read/overwrite those members directly.
//
// WARNING: member types and declaration order must match the Minuit2 header
// this file is compiled against exactly; if the upstream layout changes, the
// casts below are undefined behaviour (possible segfault). Nothing besides the
// static_assert above checks this at compile time.
struct Layout {
  MnMachinePrecision fPrecision;

  // Mirrors the parameter list exposed publicly as Parameters().
  std::vector<MinuitParameter> fParameters;
  // Presumably the internal->external index map behind ExtOfInt(); pickled
  // and restored verbatim.
  std::vector<unsigned int> fExtOfInt;

  // Present only so that the offset of the following member matches the
  // original class; presumably these hold no per-instance data — TODO confirm.
  SinParameterTransformation fDoubleLimTrafo;
  SqrtUpParameterTransformation fUpperLimTrafo;
  SqrtLowParameterTransformation fLowerLimTrafo;

  // Restored from the pickled InitialParValues(); `mutable` mirrors the
  // upstream declaration.
  mutable std::vector<double> fCache;
};
23 
size(const MnUserTransformation & self)24 int size(const MnUserTransformation& self) {
25   return static_cast<int>(self.Parameters().size());
26 }
27 
iter(const MnUserTransformation & self)28 auto iter(const MnUserTransformation& self) {
29   return py::make_iterator(self.Parameters().begin(), self.Parameters().end());
30 }
31 
getitem(const MnUserTransformation & self,int i)32 const auto& getitem(const MnUserTransformation& self, int i) {
33   if (i < 0) i += size(self);
34   if (i < 0 || i >= size(self)) throw py::index_error();
35   return self.Parameter(i);
36 }
37 
bind_usertransformation(py::module m)38 void bind_usertransformation(py::module m) {
39   py::class_<MnUserTransformation>(m, "MnUserTransformation")
40 
41       .def(py::init<>())
42 
43       .def("name", &MnUserTransformation::GetName)
44       .def("index", &MnUserTransformation::FindIndex)
45       .def("ext2int", &MnUserTransformation::Ext2int)
46       .def("int2ext", &MnUserTransformation::Int2ext)
47       .def("dint2ext", &MnUserTransformation::DInt2Ext)
48       .def("ext_of_int", &MnUserTransformation::ExtOfInt)
49       .def("int_of_ext", &MnUserTransformation::IntOfExt)
50       .def_property_readonly("variable_parameters",
51                              &MnUserTransformation::VariableParameters)
52 
53       .def("__len__", size)
54       .def("__iter__", iter)
55       .def("__getitem__", getitem)
56 
57       .def(py::pickle(
58           [](const MnUserTransformation& self) {
59             const auto d = reinterpret_cast<const Layout*>(&self);
60             return py::make_tuple(self.Precision().Eps(), self.Parameters(),
61                                   d->fExtOfInt, self.InitialParValues());
62           },
63           [](py::tuple tp) {
64             if (tp.size() != 4) throw std::runtime_error("invalid state");
65 
66             MnUserTransformation tr;
67             tr.SetPrecision(tp[0].cast<double>());
68 
69             // evil workaround, will segfault or cause UB if source layout changes
70             auto d = reinterpret_cast<Layout*>(&tr);
71             d->fParameters = tp[1].cast<std::vector<MinuitParameter>>();
72             d->fExtOfInt = tp[2].cast<std::vector<unsigned>>();
73             d->fCache = tp[3].cast<std::vector<double>>();
74             return tr;
75           }))
76 
77       ;
78 }
79