#include "bispectrum.hpp"

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <algorithm>
#include <vector>
5 
6 namespace py = pybind11;
7 
PYBIND11_MODULE(bs,m)8 PYBIND11_MODULE(bs, m)
9 {
10   m.doc() = "Bispectrum descriptor.";
11 
12   py::class_<Bispectrum>(m, "Bispectrum")
13       .def(py::init<double const,
14                     int const,
15                     int const,
16                     int const,
17                     double const,
18                     int const,
19                     int const>())
20 
21       .def(
22           "set_cutoff",
23           [](Bispectrum & d, char * name, py::array_t<double> rcuts) {
24             d.set_cutoff(name, rcuts.shape(0), rcuts.data(0));
25           },
26           py::arg("name"),
27           py::arg("rcuts").noconvert())
28 
29       .def(
30           "set_weight",
31           [](Bispectrum & d, py::array_t<double> weight) {
32             d.set_weight(weight.size(), weight.data(0));
33           },
34           py::arg("weight").noconvert())
35 
36       .def(
37           "compute_zeta",
38           [](Bispectrum & d,
39              py::array_t<double> coords,
40              py::array_t<int> species,
41              py::array_t<int> neighlist,
42              py::array_t<int> numneigh,
43              py::array_t<int> image,
44              int Natoms,
45              int Ncontrib,
46              int Ndescriptor) {
47             // create empty vectors to hold return data
48             std::vector<double> zeta(Ncontrib * Ndescriptor, 0.0);
49 
50             d.compute_B(coords.data(0),
51                         species.data(0),
52                         neighlist.data(0),
53                         numneigh.data(0),
54                         image.data(0),
55                         Natoms,
56                         Ncontrib,
57                         zeta.data(),
58                         nullptr);
59 
60             // pack zeta into a buffer that numpy array can understand
61             auto zeta_2D = py::array(py::buffer_info(
62                 zeta.data(),  // data pointer
63                 sizeof(double),  // size of one element
64                 py::format_descriptor<double>::format(),  // Python struct-style
65                                                           // format descriptor
66                 2,  // dimension
67                 {Ncontrib, Ndescriptor},  // size of each dimension
68                 {sizeof(double) * Ndescriptor, sizeof(double)}
69                 // stride (in bytes) for each dimension
70                 ));
71 
72             return zeta_2D;
73           },
74           py::arg("coords").noconvert(),
75           py::arg("species").noconvert(),
76           py::arg("neighlist").noconvert(),
77           py::arg("numneigh").noconvert(),
78           py::arg("image").noconvert(),
79           py::arg("Natoms"),
80           py::arg("Ncontrib"),
81           py::arg("Ndescriptor"))
82 
83       .def(
84           "compute_zeta_and_dzeta_dr",
85           [](Bispectrum & d,
86              py::array_t<double> coords,
87              py::array_t<int> species,
88              py::array_t<int> neighlist,
89              py::array_t<int> numneigh,
90              py::array_t<int> image,
91              int Natoms,
92              int Ncontrib,
93              int Ndescriptor) {
94             // create empty vectors to hold return data
95             std::vector<double> zeta(Ncontrib * Ndescriptor, 0.0);
96             std::vector<double> dzeta_dr(Ncontrib * Ndescriptor * Ncontrib * 3,
97                                          0.0);
98 
99             d.compute_B(coords.data(0),
100                         species.data(0),
101                         neighlist.data(0),
102                         numneigh.data(0),
103                         image.data(0),
104                         Natoms,
105                         Ncontrib,
106                         zeta.data(),
107                         dzeta_dr.data());
108 
109             // pack zeta into a buffer that numpy array can understand
110             auto zeta_2D = py::array(py::buffer_info(
111                 zeta.data(),  // data pointer
112                 sizeof(double),  // size of one element
113                 py::format_descriptor<double>::format(),  // Python struct-style
114                                                           // format descriptor
115                 2,  // dimension
116                 {Ncontrib, Ndescriptor},  // size of each dimension
117                 {sizeof(double) * Ndescriptor, sizeof(double)}
118                 // stride (in bytes) for each dimension
119                 ));
120 
121             // pack dzeta into a buffer that numpy array can understand
122             auto dzeta_dr_4D = py::array(
123                 py::buffer_info(dzeta_dr.data(),
124                                 sizeof(double),
125                                 py::format_descriptor<double>::format(),
126                                 4,
127                                 {Ncontrib, Ndescriptor, Ncontrib, 3},
128                                 {sizeof(double) * Ndescriptor * Ncontrib * 3,
129                                  sizeof(double) * Ncontrib * 3,
130                                  sizeof(double) * 3,
131                                  sizeof(double)}));
132 
133             py::tuple t(2);
134 
135             t[0] = zeta_2D;
136             t[1] = dzeta_dr_4D;
137 
138             return t;
139           },
140           py::arg("coords").noconvert(),
141           py::arg("species").noconvert(),
142           py::arg("neighlist").noconvert(),
143           py::arg("numneigh").noconvert(),
144           py::arg("image").noconvert(),
145           py::arg("Natoms"),
146           py::arg("Ncontrib"),
147           py::arg("Ndescriptor"),
148           "Return (zeta, dzeta_dr)");
149 }
150