1 #include <Minuit2/MnUserTransformation.h>
2 #include <pybind11/pybind11.h>
3 #include <pybind11/stl.h>
4 #include <type_traits>
5
6 namespace py = pybind11;
7 using namespace ROOT::Minuit2;
8
9 static_assert(std::is_standard_layout<MnUserTransformation>(), "");
10
11 struct Layout {
12 MnMachinePrecision fPrecision;
13
14 std::vector<MinuitParameter> fParameters;
15 std::vector<unsigned int> fExtOfInt;
16
17 SinParameterTransformation fDoubleLimTrafo;
18 SqrtUpParameterTransformation fUpperLimTrafo;
19 SqrtLowParameterTransformation fLowerLimTrafo;
20
21 mutable std::vector<double> fCache;
22 };
23
size(const MnUserTransformation & self)24 int size(const MnUserTransformation& self) {
25 return static_cast<int>(self.Parameters().size());
26 }
27
iter(const MnUserTransformation & self)28 auto iter(const MnUserTransformation& self) {
29 return py::make_iterator(self.Parameters().begin(), self.Parameters().end());
30 }
31
getitem(const MnUserTransformation & self,int i)32 const auto& getitem(const MnUserTransformation& self, int i) {
33 if (i < 0) i += size(self);
34 if (i < 0 || i >= size(self)) throw py::index_error();
35 return self.Parameter(i);
36 }
37
bind_usertransformation(py::module m)38 void bind_usertransformation(py::module m) {
39 py::class_<MnUserTransformation>(m, "MnUserTransformation")
40
41 .def(py::init<>())
42
43 .def("name", &MnUserTransformation::GetName)
44 .def("index", &MnUserTransformation::FindIndex)
45 .def("ext2int", &MnUserTransformation::Ext2int)
46 .def("int2ext", &MnUserTransformation::Int2ext)
47 .def("dint2ext", &MnUserTransformation::DInt2Ext)
48 .def("ext_of_int", &MnUserTransformation::ExtOfInt)
49 .def("int_of_ext", &MnUserTransformation::IntOfExt)
50 .def_property_readonly("variable_parameters",
51 &MnUserTransformation::VariableParameters)
52
53 .def("__len__", size)
54 .def("__iter__", iter)
55 .def("__getitem__", getitem)
56
57 .def(py::pickle(
58 [](const MnUserTransformation& self) {
59 const auto d = reinterpret_cast<const Layout*>(&self);
60 return py::make_tuple(self.Precision().Eps(), self.Parameters(),
61 d->fExtOfInt, self.InitialParValues());
62 },
63 [](py::tuple tp) {
64 if (tp.size() != 4) throw std::runtime_error("invalid state");
65
66 MnUserTransformation tr;
67 tr.SetPrecision(tp[0].cast<double>());
68
69 // evil workaround, will segfault or cause UB if source layout changes
70 auto d = reinterpret_cast<Layout*>(&tr);
71 d->fParameters = tp[1].cast<std::vector<MinuitParameter>>();
72 d->fExtOfInt = tp[2].cast<std::vector<unsigned>>();
73 d->fCache = tp[3].cast<std::vector<double>>();
74 return tr;
75 }))
76
77 ;
78 }
79