#include "openmc/tallies/filter_zernike.h"

#include <cmath>
#include <sstream>
#include <stdexcept>  // For invalid_argument
#include <utility>  // For pair

#include <fmt/core.h>
#include <gsl/gsl>

#include "openmc/capi.h"
#include "openmc/error.h"
#include "openmc/math_functions.h"
#include "openmc/xml_interface.h"
14 
15 namespace openmc {
16 
17 //==============================================================================
18 // ZernikeFilter implementation
19 //==============================================================================
20 
21 void
from_xml(pugi::xml_node node)22 ZernikeFilter::from_xml(pugi::xml_node node)
23 {
24   set_order(std::stoi(get_node_value(node, "order")));
25   x_ = std::stod(get_node_value(node, "x"));
26   y_ = std::stod(get_node_value(node, "y"));
27   r_ = std::stod(get_node_value(node, "r"));
28 }
29 
30 void
get_all_bins(const Particle & p,TallyEstimator estimator,FilterMatch & match) const31 ZernikeFilter::get_all_bins(const Particle& p, TallyEstimator estimator,
32                             FilterMatch& match) const
33 {
34   // Determine the normalized (r,theta) coordinates.
35   double x = p.r().x - x_;
36   double y = p.r().y - y_;
37   double r = std::sqrt(x*x + y*y) / r_;
38   double theta = std::atan2(y, x);
39 
40   if (r <= 1.0) {
41     // Compute and return the Zernike weights.
42     vector<double> zn(n_bins_);
43     calc_zn(order_, r, theta, zn.data());
44     for (int i = 0; i < n_bins_; i++) {
45       match.bins_.push_back(i);
46       match.weights_.push_back(zn[i]);
47     }
48   }
49 }
50 
51 void
to_statepoint(hid_t filter_group) const52 ZernikeFilter::to_statepoint(hid_t filter_group) const
53 {
54   Filter::to_statepoint(filter_group);
55   write_dataset(filter_group, "order", order_);
56   write_dataset(filter_group, "x", x_);
57   write_dataset(filter_group, "y", y_);
58   write_dataset(filter_group, "r", r_);
59 }
60 
61 std::string
text_label(int bin) const62 ZernikeFilter::text_label(int bin) const
63 {
64   Expects(bin >= 0 && bin < n_bins_);
65   for (int n = 0; n < order_+1; n++) {
66     int last = (n + 1) * (n + 2) / 2;
67     if (bin < last) {
68       int first = last - (n + 1);
69       int m = -n + (bin - first) * 2;
70       return fmt::format("Zernike expansion, Z{},{}", n, m);
71     }
72   }
73   UNREACHABLE();
74 }
75 
76 void
set_order(int order)77 ZernikeFilter::set_order(int order)
78 {
79   if (order < 0) {
80     throw std::invalid_argument{"Zernike order must be non-negative."};
81   }
82   order_ = order;
83   n_bins_ = ((order+1) * (order+2)) / 2;
84 }
85 
86 //==============================================================================
87 // ZernikeRadialFilter implementation
88 //==============================================================================
89 
90 void
get_all_bins(const Particle & p,TallyEstimator estimator,FilterMatch & match) const91 ZernikeRadialFilter::get_all_bins(const Particle& p, TallyEstimator estimator,
92                                   FilterMatch& match) const
93 {
94   // Determine the normalized radius coordinate.
95   double x = p.r().x - x_;
96   double y = p.r().y - y_;
97   double r = std::sqrt(x*x + y*y) / r_;
98 
99   if (r <= 1.0) {
100     // Compute and return the Zernike weights.
101     vector<double> zn(n_bins_);
102     calc_zn_rad(order_, r, zn.data());
103     for (int i = 0; i < n_bins_; i++) {
104       match.bins_.push_back(i);
105       match.weights_.push_back(zn[i]);
106     }
107   }
108 }
109 
110 std::string
text_label(int bin) const111 ZernikeRadialFilter::text_label(int bin) const
112 {
113   return "Zernike expansion, Z" + std::to_string(2*bin) + ",0";
114 }
115 
116 void
set_order(int order)117 ZernikeRadialFilter::set_order(int order)
118 {
119   ZernikeFilter::set_order(order);
120   n_bins_ = order / 2 + 1;
121 }
122 
123 //==============================================================================
124 // C-API functions
125 //==============================================================================
126 
127 std::pair<int, ZernikeFilter*>
check_zernike_filter(int32_t index)128 check_zernike_filter(int32_t index)
129 {
130   // Make sure this is a valid index to an allocated filter.
131   int err = verify_filter(index);
132   if (err) {
133     return {err, nullptr};
134   }
135 
136   // Get a pointer to the filter and downcast.
137   const auto& filt_base = model::tally_filters[index].get();
138   auto* filt = dynamic_cast<ZernikeFilter*>(filt_base);
139 
140   // Check the filter type.
141   if (!filt) {
142     set_errmsg("Not a Zernike filter.");
143     err = OPENMC_E_INVALID_TYPE;
144   }
145   return {err, filt};
146 }
147 
148 extern "C" int
openmc_zernike_filter_get_order(int32_t index,int * order)149 openmc_zernike_filter_get_order(int32_t index, int* order)
150 {
151   // Check the filter.
152   auto check_result = check_zernike_filter(index);
153   auto err = check_result.first;
154   auto filt = check_result.second;
155   if (err) return err;
156 
157   // Output the order.
158   *order = filt->order();
159   return 0;
160 }
161 
162 extern "C" int
openmc_zernike_filter_get_params(int32_t index,double * x,double * y,double * r)163 openmc_zernike_filter_get_params(int32_t index, double* x, double* y,
164                                  double* r)
165 {
166   // Check the filter.
167   auto check_result = check_zernike_filter(index);
168   auto err = check_result.first;
169   auto filt = check_result.second;
170   if (err) return err;
171 
172   // Output the params.
173   *x = filt->x();
174   *y = filt->y();
175   *r = filt->r();
176   return 0;
177 }
178 
179 extern "C" int
openmc_zernike_filter_set_order(int32_t index,int order)180 openmc_zernike_filter_set_order(int32_t index, int order)
181 {
182   // Check the filter.
183   auto check_result = check_zernike_filter(index);
184   auto err = check_result.first;
185   auto filt = check_result.second;
186   if (err) return err;
187 
188   // Update the filter.
189   filt->set_order(order);
190   return 0;
191 }
192 
193 extern "C" int
openmc_zernike_filter_set_params(int32_t index,const double * x,const double * y,const double * r)194 openmc_zernike_filter_set_params(int32_t index, const double* x,
195                                  const double* y, const double* r)
196 {
197   // Check the filter.
198   auto check_result = check_zernike_filter(index);
199   auto err = check_result.first;
200   auto filt = check_result.second;
201   if (err) return err;
202 
203   // Update the filter.
204   if (x) filt->set_x(*x);
205   if (y) filt->set_y(*y);
206   if (r) filt->set_r(*r);
207   return 0;
208 }
209 
210 } // namespace openmc
211