#include "openmc/tallies/filter_zernike.h"

#include <cmath>
#include <sstream>
#include <stdexcept> // For invalid_argument
#include <utility> // For pair

#include <fmt/core.h>
#include <gsl/gsl>

#include "openmc/capi.h"
#include "openmc/error.h"
#include "openmc/math_functions.h"
#include "openmc/xml_interface.h"
14
15 namespace openmc {
16
17 //==============================================================================
18 // ZernikeFilter implementation
19 //==============================================================================
20
21 void
from_xml(pugi::xml_node node)22 ZernikeFilter::from_xml(pugi::xml_node node)
23 {
24 set_order(std::stoi(get_node_value(node, "order")));
25 x_ = std::stod(get_node_value(node, "x"));
26 y_ = std::stod(get_node_value(node, "y"));
27 r_ = std::stod(get_node_value(node, "r"));
28 }
29
30 void
get_all_bins(const Particle & p,TallyEstimator estimator,FilterMatch & match) const31 ZernikeFilter::get_all_bins(const Particle& p, TallyEstimator estimator,
32 FilterMatch& match) const
33 {
34 // Determine the normalized (r,theta) coordinates.
35 double x = p.r().x - x_;
36 double y = p.r().y - y_;
37 double r = std::sqrt(x*x + y*y) / r_;
38 double theta = std::atan2(y, x);
39
40 if (r <= 1.0) {
41 // Compute and return the Zernike weights.
42 vector<double> zn(n_bins_);
43 calc_zn(order_, r, theta, zn.data());
44 for (int i = 0; i < n_bins_; i++) {
45 match.bins_.push_back(i);
46 match.weights_.push_back(zn[i]);
47 }
48 }
49 }
50
51 void
to_statepoint(hid_t filter_group) const52 ZernikeFilter::to_statepoint(hid_t filter_group) const
53 {
54 Filter::to_statepoint(filter_group);
55 write_dataset(filter_group, "order", order_);
56 write_dataset(filter_group, "x", x_);
57 write_dataset(filter_group, "y", y_);
58 write_dataset(filter_group, "r", r_);
59 }
60
61 std::string
text_label(int bin) const62 ZernikeFilter::text_label(int bin) const
63 {
64 Expects(bin >= 0 && bin < n_bins_);
65 for (int n = 0; n < order_+1; n++) {
66 int last = (n + 1) * (n + 2) / 2;
67 if (bin < last) {
68 int first = last - (n + 1);
69 int m = -n + (bin - first) * 2;
70 return fmt::format("Zernike expansion, Z{},{}", n, m);
71 }
72 }
73 UNREACHABLE();
74 }
75
76 void
set_order(int order)77 ZernikeFilter::set_order(int order)
78 {
79 if (order < 0) {
80 throw std::invalid_argument{"Zernike order must be non-negative."};
81 }
82 order_ = order;
83 n_bins_ = ((order+1) * (order+2)) / 2;
84 }
85
86 //==============================================================================
87 // ZernikeRadialFilter implementation
88 //==============================================================================
89
90 void
get_all_bins(const Particle & p,TallyEstimator estimator,FilterMatch & match) const91 ZernikeRadialFilter::get_all_bins(const Particle& p, TallyEstimator estimator,
92 FilterMatch& match) const
93 {
94 // Determine the normalized radius coordinate.
95 double x = p.r().x - x_;
96 double y = p.r().y - y_;
97 double r = std::sqrt(x*x + y*y) / r_;
98
99 if (r <= 1.0) {
100 // Compute and return the Zernike weights.
101 vector<double> zn(n_bins_);
102 calc_zn_rad(order_, r, zn.data());
103 for (int i = 0; i < n_bins_; i++) {
104 match.bins_.push_back(i);
105 match.weights_.push_back(zn[i]);
106 }
107 }
108 }
109
110 std::string
text_label(int bin) const111 ZernikeRadialFilter::text_label(int bin) const
112 {
113 return "Zernike expansion, Z" + std::to_string(2*bin) + ",0";
114 }
115
116 void
set_order(int order)117 ZernikeRadialFilter::set_order(int order)
118 {
119 ZernikeFilter::set_order(order);
120 n_bins_ = order / 2 + 1;
121 }
122
123 //==============================================================================
124 // C-API functions
125 //==============================================================================
126
127 std::pair<int, ZernikeFilter*>
check_zernike_filter(int32_t index)128 check_zernike_filter(int32_t index)
129 {
130 // Make sure this is a valid index to an allocated filter.
131 int err = verify_filter(index);
132 if (err) {
133 return {err, nullptr};
134 }
135
136 // Get a pointer to the filter and downcast.
137 const auto& filt_base = model::tally_filters[index].get();
138 auto* filt = dynamic_cast<ZernikeFilter*>(filt_base);
139
140 // Check the filter type.
141 if (!filt) {
142 set_errmsg("Not a Zernike filter.");
143 err = OPENMC_E_INVALID_TYPE;
144 }
145 return {err, filt};
146 }
147
148 extern "C" int
openmc_zernike_filter_get_order(int32_t index,int * order)149 openmc_zernike_filter_get_order(int32_t index, int* order)
150 {
151 // Check the filter.
152 auto check_result = check_zernike_filter(index);
153 auto err = check_result.first;
154 auto filt = check_result.second;
155 if (err) return err;
156
157 // Output the order.
158 *order = filt->order();
159 return 0;
160 }
161
162 extern "C" int
openmc_zernike_filter_get_params(int32_t index,double * x,double * y,double * r)163 openmc_zernike_filter_get_params(int32_t index, double* x, double* y,
164 double* r)
165 {
166 // Check the filter.
167 auto check_result = check_zernike_filter(index);
168 auto err = check_result.first;
169 auto filt = check_result.second;
170 if (err) return err;
171
172 // Output the params.
173 *x = filt->x();
174 *y = filt->y();
175 *r = filt->r();
176 return 0;
177 }
178
179 extern "C" int
openmc_zernike_filter_set_order(int32_t index,int order)180 openmc_zernike_filter_set_order(int32_t index, int order)
181 {
182 // Check the filter.
183 auto check_result = check_zernike_filter(index);
184 auto err = check_result.first;
185 auto filt = check_result.second;
186 if (err) return err;
187
188 // Update the filter.
189 filt->set_order(order);
190 return 0;
191 }
192
193 extern "C" int
openmc_zernike_filter_set_params(int32_t index,const double * x,const double * y,const double * r)194 openmc_zernike_filter_set_params(int32_t index, const double* x,
195 const double* y, const double* r)
196 {
197 // Check the filter.
198 auto check_result = check_zernike_filter(index);
199 auto err = check_result.first;
200 auto filt = check_result.second;
201 if (err) return err;
202
203 // Update the filter.
204 if (x) filt->set_x(*x);
205 if (y) filt->set_y(*y);
206 if (r) filt->set_r(*r);
207 return 0;
208 }
209
210 } // namespace openmc
211