#include "bispectrum.hpp"

#include <algorithm>
#include <string>
#include <vector>

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;
7
PYBIND11_MODULE(bs,m)8 PYBIND11_MODULE(bs, m)
9 {
10 m.doc() = "Bispectrum descriptor.";
11
12 py::class_<Bispectrum>(m, "Bispectrum")
13 .def(py::init<double const,
14 int const,
15 int const,
16 int const,
17 double const,
18 int const,
19 int const>())
20
21 .def(
22 "set_cutoff",
23 [](Bispectrum & d, char * name, py::array_t<double> rcuts) {
24 d.set_cutoff(name, rcuts.shape(0), rcuts.data(0));
25 },
26 py::arg("name"),
27 py::arg("rcuts").noconvert())
28
29 .def(
30 "set_weight",
31 [](Bispectrum & d, py::array_t<double> weight) {
32 d.set_weight(weight.size(), weight.data(0));
33 },
34 py::arg("weight").noconvert())
35
36 .def(
37 "compute_zeta",
38 [](Bispectrum & d,
39 py::array_t<double> coords,
40 py::array_t<int> species,
41 py::array_t<int> neighlist,
42 py::array_t<int> numneigh,
43 py::array_t<int> image,
44 int Natoms,
45 int Ncontrib,
46 int Ndescriptor) {
47 // create empty vectors to hold return data
48 std::vector<double> zeta(Ncontrib * Ndescriptor, 0.0);
49
50 d.compute_B(coords.data(0),
51 species.data(0),
52 neighlist.data(0),
53 numneigh.data(0),
54 image.data(0),
55 Natoms,
56 Ncontrib,
57 zeta.data(),
58 nullptr);
59
60 // pack zeta into a buffer that numpy array can understand
61 auto zeta_2D = py::array(py::buffer_info(
62 zeta.data(), // data pointer
63 sizeof(double), // size of one element
64 py::format_descriptor<double>::format(), // Python struct-style
65 // format descriptor
66 2, // dimension
67 {Ncontrib, Ndescriptor}, // size of each dimension
68 {sizeof(double) * Ndescriptor, sizeof(double)}
69 // stride (in bytes) for each dimension
70 ));
71
72 return zeta_2D;
73 },
74 py::arg("coords").noconvert(),
75 py::arg("species").noconvert(),
76 py::arg("neighlist").noconvert(),
77 py::arg("numneigh").noconvert(),
78 py::arg("image").noconvert(),
79 py::arg("Natoms"),
80 py::arg("Ncontrib"),
81 py::arg("Ndescriptor"))
82
83 .def(
84 "compute_zeta_and_dzeta_dr",
85 [](Bispectrum & d,
86 py::array_t<double> coords,
87 py::array_t<int> species,
88 py::array_t<int> neighlist,
89 py::array_t<int> numneigh,
90 py::array_t<int> image,
91 int Natoms,
92 int Ncontrib,
93 int Ndescriptor) {
94 // create empty vectors to hold return data
95 std::vector<double> zeta(Ncontrib * Ndescriptor, 0.0);
96 std::vector<double> dzeta_dr(Ncontrib * Ndescriptor * Ncontrib * 3,
97 0.0);
98
99 d.compute_B(coords.data(0),
100 species.data(0),
101 neighlist.data(0),
102 numneigh.data(0),
103 image.data(0),
104 Natoms,
105 Ncontrib,
106 zeta.data(),
107 dzeta_dr.data());
108
109 // pack zeta into a buffer that numpy array can understand
110 auto zeta_2D = py::array(py::buffer_info(
111 zeta.data(), // data pointer
112 sizeof(double), // size of one element
113 py::format_descriptor<double>::format(), // Python struct-style
114 // format descriptor
115 2, // dimension
116 {Ncontrib, Ndescriptor}, // size of each dimension
117 {sizeof(double) * Ndescriptor, sizeof(double)}
118 // stride (in bytes) for each dimension
119 ));
120
121 // pack dzeta into a buffer that numpy array can understand
122 auto dzeta_dr_4D = py::array(
123 py::buffer_info(dzeta_dr.data(),
124 sizeof(double),
125 py::format_descriptor<double>::format(),
126 4,
127 {Ncontrib, Ndescriptor, Ncontrib, 3},
128 {sizeof(double) * Ndescriptor * Ncontrib * 3,
129 sizeof(double) * Ncontrib * 3,
130 sizeof(double) * 3,
131 sizeof(double)}));
132
133 py::tuple t(2);
134
135 t[0] = zeta_2D;
136 t[1] = dzeta_dr_4D;
137
138 return t;
139 },
140 py::arg("coords").noconvert(),
141 py::arg("species").noconvert(),
142 py::arg("neighlist").noconvert(),
143 py::arg("numneigh").noconvert(),
144 py::arg("image").noconvert(),
145 py::arg("Natoms"),
146 py::arg("Ncontrib"),
147 py::arg("Ndescriptor"),
148 "Return (zeta, dzeta_dr)");
149 }
150