1from numbers import Real
2from math import exp, erf, pi, sqrt
3from copy import deepcopy
4import warnings
5
6import os
7import h5py
8import pickle
9import numpy as np
10from scipy.signal import find_peaks
11import matplotlib.pyplot as plt
12
13import openmc.checkvalue as cv
14from ..exceptions import DataError
15from ..mixin import EqualityMixin
16from . import WMP_VERSION, WMP_VERSION_MAJOR
17from .data import K_BOLTZMANN
18from .neutron import IncidentNeutron
19from .resonance import ResonanceRange
20
21
# Column indices into a row of the multipole ``data`` array: the pole comes
# first, followed by the scattering, absorption and fission residues.
_MP_EA, _MP_RS, _MP_RA, _MP_RF = 0, 1, 2, 3

# Reaction indices for the polynomial curve fit coefficients
# (scattering, absorption, fission).
_FIT_S, _FIT_A, _FIT_F = 0, 1, 2

# Upper temperature limit (K). Doppler broadening at this temperature sets
# how far each fitting piece / window is extended in energy.
TEMPERATURE_LIMIT = 3000

# Logging control: values of ``log`` at or above this level print
# per-iteration details.
DETAILED_LOGGING = 2
40
41
def _faddeeva(z):
    r"""Evaluate the complex Faddeeva function.

    Technically, the value we want is given by the equation:

    .. math::
        w(z) = \frac{i}{\pi} \int_{-\infty}^{\infty} \frac{1}{z - t}
        \exp(-t^2) \text{d}t

    as shown in Equation 63 from Hwang, R. N. "A rigorous pole
    representation of multilevel cross sections and its practical
    applications." Nuclear Science and Engineering 96.3 (1987): 192-209.

    The :func:`scipy.special.wofz` function evaluates
    :math:`w(z) = \exp(-z^2) \text{erfc}(-iz)`. These two forms of the Faddeeva
    function are related by a reflection through the real axis.

    If we call the integral form :math:`w_\text{int}`, and the function form
    :math:`w_\text{fun}`:

    .. math::
        w_\text{int}(z) =
        \begin{cases}
            w_\text{fun}(z) & \text{for } \text{Im}(z) > 0\\
            -w_\text{fun}(z^*)^* & \text{for } \text{Im}(z) < 0
        \end{cases}

    Parameters
    ----------
    z : complex
        Argument to the Faddeeva function.

    Returns
    -------
    complex
        :math:`\frac{i}{\pi} \int_{-\infty}^{\infty} \frac{1}{z - t} \exp(-t^2)
        \text{d}t`

    """
    from scipy.special import wofz

    # A positive complex argument (angle) selects the direct form; every
    # other argument goes through the conjugate reflection.
    use_direct_form = np.angle(z) > 0
    if not use_direct_form:
        reflected = wofz(z.conjugate())
        return -np.conj(reflected)
    return wofz(z)
86
87
def _broaden_wmp_polynomials(E, dopp, n):
    r"""Evaluate Doppler-broadened windowed multipole curvefit.

    The curvefit is a polynomial of the form :math:`\frac{a}{E}
    + \frac{b}{\sqrt{E}} + c + d \sqrt{E} + \ldots`; this function returns
    the broadened value of each individual term.

    Parameters
    ----------
    E : float
        Energy to evaluate at.
    dopp : float
        sqrt(atomic weight ratio / kT) with kT in eV, so that
        ``sqrt(E) * dopp`` is dimensionless.
    n : int
        Number of components to the polynomial (at least 3).

    Returns
    -------
    np.ndarray
        The value of each Doppler-broadened curvefit polynomial term.

    """
    sqrt_e = sqrt(E)
    beta = sqrt_e * dopp
    h = 0.5 / dopp**2   # 1 / (2 dopp^2)
    h2 = h**2           # 1 / (4 dopp^4)

    if beta > 6.0:
        # erf(6) is 1 to machine precision and the exp(-beta**2) term is
        # negligible, so skip evaluating them.
        erf_beta, exp_m_beta2 = 1.0, 0.0
    else:
        erf_beta, exp_m_beta2 = erf(beta), exp(-beta**2)

    # The first three terms (1/E, 1/sqrt(E), const) are always present.
    terms = np.zeros(n)
    terms[0] = erf_beta / E
    terms[1] = 1.0 / sqrt_e
    terms[2] = terms[0] * (h + E) + exp_m_beta2 / (beta * sqrt(pi))

    # Higher orders follow a two-step recurrence; the first step has no
    # terms[k - 2] contribution.
    if n > 3:
        terms[3] = terms[1] * (E + 3.0 * h)
    for k in range(2, n - 2):
        terms[k + 2] = (terms[k] * (E + (1.0 + 2.0 * k) * h)
                        - terms[k - 2] * (k - 1.0) * k * h2)

    return terms
144
145
def _vectfit_xs(energy, ce_xs, mts, rtol=1e-3, atol=1e-5, orders=None,
                n_vf_iter=30, log=False, path_out=None):
    """Convert point-wise cross section to multipole data via vector fitting.

    Vector fitting (VF) is applied to sigma*E as a function of sqrt(E). For
    each candidate order (number of poles), up to ``n_vf_iter`` VF iterations
    are run. Each result is assessed on an energy grid 10 times finer than
    ``energy`` (linearly interpolated), and the best result over all orders
    is kept.

    Parameters
    ----------
    energy : np.ndarray
        Energy array
    ce_xs : np.ndarray
        Point-wise cross sections to be fitted, with shape (number of reactions,
        number of energy points)
    mts : Iterable of int
        Reaction list
    rtol : float, optional
        Relative error tolerance
    atol : float, optional
        Absolute error tolerance
    orders : Iterable of int, optional
        A list of orders (number of poles) to be searched. Orders are rounded
        down to even integers, orders below 2 are dropped, and the rest are
        searched in increasing order.
    n_vf_iter : int, optional
        Number of maximum VF iterations
    log : bool or int, optional
        Whether to print running logs (use int for verbosity control)
    path_out : str, optional
        Path to save the figures to show discrepancies between the original and
        fitted cross sections for different reactions

    Returns
    -------
    tuple
        (poles, residues). Real poles come first, followed by one pole of each
        complex-conjugate pair. The residues of the pairs are doubled, and all
        residues are divided by 1j.

    Raises
    ------
    ValueError
        If ``ce_xs`` does not have shape ``(len(mts), energy.size)``, or if
        ``orders`` contains no order >= 2.
    RuntimeError
        If no acceptable fit was found, or if the complex poles of the best
        fit do not come in conjugate pairs.

    """
    ne = energy.size
    nmt = len(mts)
    if ce_xs.shape != (nmt, ne):
        raise ValueError('Inconsistent cross section data.')

    # detect peaks (resonances) and determine VF order search range
    peaks, _ = find_peaks(ce_xs[0] + ce_xs[1])
    n_peaks = peaks.size
    if orders is not None:
        # Make the orders even, unique and increasing. The search and the
        # early-stopping logic below rely on increasing order, which iterating
        # a plain set does not guarantee.
        orders = sorted({int(i/2)*2 for i in orders if i >= 2})
        if not orders:
            raise ValueError('No fitting order >= 2 was given.')
    else:
        lowest_order = max(2, 2*n_peaks)
        highest_order = max(200, 4*n_peaks)
        orders = list(range(lowest_order, highest_order + 1, 2))

    # import vectfit package: https://github.com/liangjg/vectfit
    # (deferred until the input has been validated)
    import vectfit as vf

    # construct test data: interpolate xs with finer grids
    n_finer = 10
    ne_test = (ne - 1)*n_finer + 1
    test_energy = np.interp(np.arange(ne_test),
                            np.arange(ne_test, step=n_finer), energy)
    test_energy[[0, -1]] = energy[[0, -1]]  # avoid numerical issue
    test_xs_ref = np.zeros((nmt, ne_test))
    for i in range(nmt):
        test_xs_ref[i] = np.interp(test_energy, energy, ce_xs[i])

    if log:
        print("  energy: {:.3e} to {:.3e} eV ({} points)".format(
              energy[0], energy[-1], ne))
        print("  error tolerance: rtol={}, atol={}".format(rtol, atol))

    # transform xs (sigma) and energy (E) to f (sigma*E) and s (sqrt(E)) to be
    # compatible with the multipole representation
    f = ce_xs * energy
    s = np.sqrt(energy)
    test_s = np.sqrt(test_energy)

    # inverse weighting is used for minimizing the relative deviation instead of
    # absolute deviation in vector fitting
    with np.errstate(divide='ignore'):
        weight = 1.0/f

    # avoid too large weights which will harm the fitting accuracy
    min_cross_section = 1e-7
    for i in range(nmt):
        if np.all(ce_xs[i] <= min_cross_section):
            weight[i] = 1.0
        elif np.any(ce_xs[i] <= min_cross_section):
            weight[i, ce_xs[i] <= min_cross_section] = \
               max(weight[i, ce_xs[i] > min_cross_section])

    if log:
        print("Found {} peaks".format(n_peaks))
        print("Fitting orders from {} to {}".format(orders[0], orders[-1]))

    # perform VF with increasing orders
    found_ideal = False
    n_discarded = 0  # for acceleration, number of discarded searches
    best_quality = best_ratio = -np.inf
    best_poles = best_residues = best_test_xs = best_relerr = None
    for i, order in enumerate(orders):
        if log:
            print("Order={}({}/{})".format(order, i + 1, len(orders)))
        # initial guessed poles
        poles_r = np.linspace(s[0], s[-1], order//2)
        poles = poles_r + poles_r*0.01j
        poles = np.sort(np.append(poles, np.conj(poles)))

        found_better = False
        # fitting iteration
        for i_vf in range(n_vf_iter):
            if log >= DETAILED_LOGGING:
                print("VF iteration {}/{}".format(i_vf + 1, n_vf_iter))

            # call vf
            poles, residues, cf, f_fit, rms = vf.vectfit(f, s, poles, weight)

            # convert real pole to conjugate pairs
            n_real_poles = 0
            new_poles = []
            for p in poles:
                p_r, p_i = np.real(p), np.imag(p)
                if (s[0] <= p_r <= s[-1]) and p_i == 0.:
                    new_poles += [p_r+p_r*0.01j, p_r-p_r*0.01j]
                    n_real_poles += 1
                else:
                    new_poles += [p]
            new_poles = np.array(new_poles)
            # re-calculate residues if poles changed
            if n_real_poles > 0:
                if log >= DETAILED_LOGGING:
                    print("  # real poles: {}".format(n_real_poles))
                new_poles, residues, cf, f_fit, rms = \
                      vf.vectfit(f, s, new_poles, weight, skip_pole=True)

            # assess the result on test grid
            test_xs = vf.evaluate(test_s, new_poles, residues) / test_energy
            abserr = np.abs(test_xs - test_xs_ref)
            with np.errstate(invalid='ignore', divide='ignore'):
                relerr = abserr / test_xs_ref
                if np.any(np.isnan(abserr)):
                    maxre, ratio, ratio2 = np.inf, -np.inf, -np.inf
                elif np.all(abserr <= atol):
                    maxre, ratio, ratio2 = 0., 1., 1.
                else:
                    maxre = np.max(relerr[abserr > atol])
                    ratio = np.sum((relerr < rtol) | (abserr < atol)) / relerr.size
                    ratio2 = np.sum((relerr < 10*rtol) | (abserr < atol)) / relerr.size

            # define a metric for choosing the best fitting results
            # basically, it is preferred to have more points within accuracy
            # tolerance, smaller maximum deviation and fewer poles
            #TODO: improve the metric with clearer basis
            quality = ratio + ratio2 - min(0.1*maxre, 1) - 0.001*new_poles.size

            # negative cross sections disqualify a fit
            if np.any(test_xs < -atol):
                quality = -np.inf

            if log >= DETAILED_LOGGING:
                print("  # poles: {}".format(new_poles.size))
                print("  Max relative error: {:.3f}%".format(maxre*100))
                print("  Satisfaction: {:.1f}%, {:.1f}%".format(ratio*100, ratio2*100))
                print("  Quality: {:.2f}".format(quality))

            if quality > best_quality:
                if log >= DETAILED_LOGGING:
                    print("  Best so far!")
                found_better = True
                best_quality, best_ratio = quality, ratio
                best_poles, best_residues = new_poles, residues
                best_test_xs, best_relerr = test_xs, relerr
                if best_ratio >= 1.0:
                    if log:
                        print("Found ideal results. Stop!")
                    found_ideal = True
                    break
            else:
                if log >= DETAILED_LOGGING:
                    print("  Discarded!")

        if found_ideal:
            break

        # acceleration: stop searching higher orders once several orders in
        # a row brought no improvement over an already decent fit
        if found_better:
            n_discarded = 0
        else:
            if order > max(2*n_peaks, 50) and best_ratio > 0.7:
                n_discarded += 1
                if n_discarded >= 10 or (n_discarded >= 5 and best_ratio > 0.9):
                    if log >= DETAILED_LOGGING:
                        print("Couldn't get better results. Stop!")
                    break

    if best_poles is None:
        raise RuntimeError('Vector fitting failed: no acceptable fit was '
                           'found (all fits had negative or invalid cross '
                           'sections).')

    # merge conjugate poles: keep real poles and the first pole of each
    # complex-conjugate pair
    real_idx = []
    conj_idx = []
    found_conj = False
    for i, p in enumerate(best_poles):
        if found_conj:
            found_conj = False
            continue
        if np.imag(p) == 0.:
            real_idx.append(i)
        elif i + 1 < best_poles.size and np.conj(p) == best_poles[i + 1]:
            found_conj = True
            conj_idx.append(i)
        else:
            raise RuntimeError("Complex poles are not conjugate!")
    if log:
        print("Found {} real poles and {} conjugate complex pairs.".format(
               len(real_idx), len(conj_idx)))
    mp_poles = best_poles[real_idx + conj_idx]
    mp_residues = np.concatenate((best_residues[:, real_idx],
                                  best_residues[:, conj_idx]*2), axis=1)/1j
    if log:
        print("Final number of poles: {}".format(mp_poles.size))

    if path_out:
        os.makedirs(path_out, exist_ok=True)
        for i, mt in enumerate(mts):
            if not test_xs_ref[i].any():
                continue
            fig, ax1 = plt.subplots()
            lns1 = ax1.loglog(test_energy, test_xs_ref[i], 'g', label="ACE xs")
            lns2 = ax1.loglog(test_energy, best_test_xs[i], 'b', label="VF xs")
            ax2 = ax1.twinx()
            lns3 = ax2.loglog(test_energy, best_relerr[i], 'r',
                              label="Relative error", alpha=0.5)
            lns = lns1 + lns2 + lns3
            labels = [l.get_label() for l in lns]
            ax1.legend(lns, labels, loc='best')
            ax1.set_xlabel('energy (eV)')
            ax1.set_ylabel('cross section (b)', color='b')
            ax1.tick_params('y', colors='b')
            ax2.set_ylabel('relative error', color='r')
            ax2.tick_params('y', colors='r')

            plt.title("MT {} vector fitted with {} poles".format(mt, mp_poles.size))
            fig.tight_layout()
            fig_file = os.path.join(path_out, "{:.0f}-{:.0f}_MT{}.png".format(
                                    energy[0], energy[-1], mt))
            plt.savefig(fig_file)
            plt.close()
            if log:
                print("Saved figure: {}".format(fig_file))

    return (mp_poles, mp_residues)
392
393
def vectfit_nuclide(endf_file, njoy_error=5e-4, vf_pieces=None,
                    log=False, path_out=None, mp_filename=None, **kwargs):
    r"""Generate multipole data for a nuclide from ENDF.

    The 0 K point-wise cross sections are limited to the resolved resonance
    range and cut off at the first threshold reaction. They are then split
    into ``vf_pieces`` pieces, equally spaced in sqrt(E). Each piece is
    extended for Doppler broadening at TEMPERATURE_LIMIT and vector fitted
    separately.

    Parameters
    ----------
    endf_file : str
        Path to ENDF evaluation
    njoy_error : float, optional
        Fractional error tolerance for processing point-wise data with NJOY
    vf_pieces : integer, optional
        Number of equal-in-momentum spaced energy pieces for data fitting. If
        not given, it is chosen from the number of resonance peaks and energy
        points.
    log : bool or int, optional
        Whether to print running logs (use int for verbosity control)
    path_out : str, optional
        Path to write out multipole data file and vector fitting figures
    mp_filename : str, optional
        File name to write out multipole data
    **kwargs
        Keyword arguments passed to :func:`openmc.data.multipole._vectfit_xs`

    Returns
    -------
    mp_data
        Dictionary containing necessary multipole data of the nuclide, with
        keys "name", "AWR", "E_min", "E_max", "poles" and "residues". The
        values for "poles" and "residues" are lists with one entry per piece.

    Raises
    ------
    ValueError
        If ``vf_pieces`` is given and is less than 1.

    """
    # Validate before the (expensive) NJOY run; zero or negative pieces would
    # otherwise divide by zero or silently produce empty data.
    if vf_pieces is not None and vf_pieces < 1:
        raise ValueError(
            'vf_pieces must be at least 1, got {}.'.format(vf_pieces))

    # ======================================================================
    # PREPARE POINT-WISE XS

    # make 0K ACE data using njoy
    if log:
        print("Running NJOY to get 0K point-wise data (error={})...".format(njoy_error))

    nuc_ce = IncidentNeutron.from_njoy(endf_file, temperatures=[0.0],
             error=njoy_error, broadr=False, heatr=False, purr=False)

    if log:
        print("Parsing cross sections within resolved resonance range...")

    # Determine upper energy: the lower of RRR upper bound and first threshold.
    # NOTE(review): an object of exactly the ResonanceRange base type is
    # presumably one without resolved parameters, hence the exact type check.
    endf_res = IncidentNeutron.from_endf(endf_file).resonances
    if hasattr(endf_res, 'resolved') and \
       hasattr(endf_res.resolved, 'energy_max') and \
       type(endf_res.resolved) is not ResonanceRange:
        E_max = endf_res.resolved.energy_max
    elif hasattr(endf_res, 'unresolved') and \
         hasattr(endf_res.unresolved, 'energy_min'):
        E_max = endf_res.unresolved.energy_min
    else:
        E_max = nuc_ce.energy['0K'][-1]
    E_max_idx = np.searchsorted(nuc_ce.energy['0K'], E_max, side='right') - 1
    for mt in nuc_ce.reactions:
        if hasattr(nuc_ce.reactions[mt].xs['0K'], '_threshold_idx'):
            threshold_idx = nuc_ce.reactions[mt].xs['0K']._threshold_idx
            if 0 < threshold_idx < E_max_idx:
                E_max_idx = threshold_idx

    # parse energy and cross sections
    energy = nuc_ce.energy['0K'][:E_max_idx + 1]
    E_min, E_max = energy[0], energy[-1]
    n_points = energy.size
    total_xs = nuc_ce[1].xs['0K'](energy)
    elastic_xs = nuc_ce[2].xs['0K'](energy)

    try:
        absorption_xs = nuc_ce[27].xs['0K'](energy)
    except KeyError:
        absorption_xs = np.zeros_like(total_xs)

    fissionable = False
    try:
        fission_xs = nuc_ce[18].xs['0K'](energy)
        fissionable = True
    except KeyError:
        pass

    # make vectors
    if fissionable:
        ce_xs = np.vstack((elastic_xs, absorption_xs, fission_xs))
        mts = [2, 27, 18]
    else:
        ce_xs = np.vstack((elastic_xs, absorption_xs))
        mts = [2, 27]

    if log:
        print("  MTs: {}".format(mts))
        print("  Energy range: {:.3e} to {:.3e} eV ({} points)".format(
              E_min, E_max, n_points))

    # ======================================================================
    # PERFORM VECTOR FITTING

    if vf_pieces is None:
        # divide into pieces for complex nuclides
        peaks, _ = find_peaks(total_xs)
        n_peaks = peaks.size
        if n_peaks > 200 or n_points > 30000 or n_peaks * n_points > 100*10000:
            vf_pieces = max(5, n_peaks // 50, n_points // 2000)
        else:
            vf_pieces = 1
    piece_width = (sqrt(E_max) - sqrt(E_min)) / vf_pieces

    alpha = nuc_ce.atomic_weight_ratio/(K_BOLTZMANN*TEMPERATURE_LIMIT)

    poles, residues = [], []
    # VF piece by piece
    for i_piece in range(vf_pieces):
        if log:
            print("Vector fitting piece {}/{}...".format(i_piece + 1, vf_pieces))
        # start E of this piece (half a piece below its nominal start, then
        # extended downwards for Doppler broadening)
        e_bound = (sqrt(E_min) + piece_width*(i_piece-0.5))**2
        if i_piece == 0 or sqrt(alpha*e_bound) < 4.0:
            e_start = E_min
            e_start_idx = 0
        else:
            e_start = max(E_min, (sqrt(alpha*e_bound) - 4.0)**2/alpha)
            e_start_idx = np.searchsorted(energy, e_start, side='right') - 1
        # end E of this piece, extended upwards for Doppler broadening
        e_bound = (sqrt(E_min) + piece_width*(i_piece + 1))**2
        e_end = min(E_max, (sqrt(alpha*e_bound) + 4.0)**2/alpha)
        e_end_idx = np.searchsorted(energy, e_end, side='left') + 1
        e_idx = slice(e_start_idx, min(e_end_idx + 1, n_points))

        p, r = _vectfit_xs(energy[e_idx], ce_xs[:, e_idx], mts, log=log,
                           path_out=path_out, **kwargs)

        poles.append(p)
        residues.append(r)

    # collect multipole data into a dictionary
    mp_data = {"name": nuc_ce.name,
               "AWR": nuc_ce.atomic_weight_ratio,
               "E_min": E_min,
               "E_max": E_max,
               "poles": poles,
               "residues": residues}

    # dump multipole data to file
    if path_out:
        os.makedirs(path_out, exist_ok=True)
        if not mp_filename:
            mp_filename = "{}_mp.pickle".format(nuc_ce.name)
        mp_filename = os.path.join(path_out, mp_filename)
        with open(mp_filename, 'wb') as f:
            pickle.dump(mp_data, f)
        if log:
            print("Dumped multipole data to file: {}".format(mp_filename))

    return mp_data
546
547
def _windowing(mp_data, n_cf, rtol=1e-3, atol=1e-5, n_win=None, spacing=None,
               log=False):
    """Generate windowed multipole library from multipole data with specific
    settings of window size, curve fit order, etc.

    The energy range is divided into ``n_win`` inner windows of equal width in
    sqrt(E) space. Each window starts with no poles. The pole whose real part
    is nearest to the window center is added, one at a time, until the
    included poles plus a least-squares polynomial curve fit of the remainder
    reproduce the reference cross section within tolerance.

    Parameters
    ----------
    mp_data : dict
        Multipole data with keys "name", "AWR", "E_min", "E_max", "poles" and
        "residues". The last two are lists with one array per energy piece.
        The dictionary and its lists are not modified.
    n_cf : int
        Curve fitting order
    rtol : float, optional
        Maximum relative error tolerance
    atol : float, optional
        Minimum absolute error tolerance
    n_win : int, optional
        Number of equal-in-momentum spaced energy windows. Defaults to 1000
        unless ``spacing`` is given.
    spacing : float, optional
        Inner window spacing (sqrt energy space). Only used when ``n_win`` is
        not given; E_max is then lowered to the end of the last full window.
    log : bool or int, optional
        Whether to print running logs (use int for verbosity control)

    Returns
    -------
    openmc.data.WindowedMultipole
        Resonant cross sections represented in the windowed multipole
        format.

    Raises
    ------
    ValueError
        If there is not at least one window, or if the window spacing is
        larger than the piece spacing.
    RuntimeError
        If a pure curve fit fails for the first window.

    """
    # unpack multipole data; the pole/residue lists are copied so that the
    # sorting below does not modify the caller's data
    name = mp_data["name"]
    awr = mp_data["AWR"]
    E_min = mp_data["E_min"]
    E_max = mp_data["E_max"]
    mp_poles = list(mp_data["poles"])
    mp_residues = list(mp_data["residues"])

    n_pieces = len(mp_poles)
    piece_width = (sqrt(E_max) - sqrt(E_min)) / n_pieces
    alpha = awr / (K_BOLTZMANN*TEMPERATURE_LIMIT)

    # determine window size
    if n_win is None:
        if spacing is not None:
            # ensure the windows are within the multipole energy range
            n_win = int((sqrt(E_max) - sqrt(E_min)) / spacing)
            E_max = (sqrt(E_min) + n_win*spacing)**2
        else:
            n_win = 1000
    if n_win < 1:
        raise ValueError('At least one window is required.')
    # inner window size
    spacing = (sqrt(E_max) - sqrt(E_min)) / n_win
    # make sure inner window size is smaller than energy piece size
    if spacing > piece_width:
        raise ValueError('Window spacing cannot be larger than piece spacing.')

    # import vectfit package: https://github.com/liangjg/vectfit
    import vectfit as vf

    if log:
        print("Windowing:")
        print("  config: # windows={}, spacing={}, CF order={}".format(
               n_win, spacing, n_cf))
        print("  error tolerance: rtol={}, atol={}".format(rtol, atol))

    # sort poles (and residues) by the real component of the pole
    for ip in range(n_pieces):
        indices = mp_poles[ip].argsort()
        mp_poles[ip] = mp_poles[ip][indices]
        mp_residues[ip] = mp_residues[ip][:, indices]

    # initialize an array to record whether each pole is used or not
    poles_unused = [np.ones_like(p, dtype=int) for p in mp_poles]

    # optimize the windows: the goal is to find the least set of significant
    # consecutive poles and curve fit coefficients to reproduce cross section
    win_data = []
    for iw in range(n_win):
        if log >= DETAILED_LOGGING:
            print("Processing window {}/{}...".format(iw + 1, n_win))

        # inner window boundaries
        inbegin = sqrt(E_min) + spacing * iw
        inend = inbegin + spacing
        incenter = (inbegin + inend) / 2.0
        # extend window energy range for Doppler broadening
        if iw == 0 or sqrt(alpha)*inbegin < 4.0:
            e_start = inbegin**2
        else:
            e_start = max(E_min, (sqrt(alpha)*inbegin - 4.0)**2/alpha)
        e_end = min(E_max, (sqrt(alpha)*inend + 4.0)**2/alpha)

        # locate piece and relevant poles
        i_piece = min(n_pieces - 1, int((inbegin - sqrt(E_min))/piece_width + 0.5))
        poles, residues = mp_poles[i_piece], mp_residues[i_piece]
        n_poles = poles.size

        # generate energy points for fitting: equally spaced in momentum
        n_points = min(max(100, int((e_end - e_start)*4)), 10000)
        energy_sqrt = np.linspace(np.sqrt(e_start), np.sqrt(e_end), n_points)
        energy = energy_sqrt**2

        # reference xs from multipole form, note the residue terms in the
        # multipole and vector fitting representations differ by a 1j
        xs_ref = vf.evaluate(energy_sqrt, poles, residues*1j) / energy

        # curve fit matrix
        matrix = np.vstack([energy**(0.5*i - 1) for i in range(n_cf + 1)]).T

        # start from 0 poles, initialize pointers to the center nearest pole
        # (argmin is undefined for a piece without poles)
        if n_poles > 0:
            center_pole_ind = np.argmin(np.fabs(poles.real - incenter))
        else:
            center_pole_ind = 0
        lp = rp = center_pole_ind
        while True:
            if log >= DETAILED_LOGGING:
                print("Trying poles {} to {}".format(lp, rp))

            # calculate the cross sections contributed by the windowed poles
            if rp > lp:
                xs_wp = vf.evaluate(energy_sqrt, poles[lp:rp],
                                    residues[:, lp:rp]*1j) / energy
            else:
                xs_wp = np.zeros_like(xs_ref)

            # do least square curve fit on the remains
            coefs = np.linalg.lstsq(matrix, (xs_ref - xs_wp).T, rcond=None)[0]
            xs_fit = (matrix @ coefs).T

            # assess the result
            abserr = np.abs(xs_fit + xs_wp - xs_ref)
            with np.errstate(invalid='ignore', divide='ignore'):
                relerr = abserr / xs_ref
            if not np.any(np.isnan(abserr)):
                # relative errors only matter where the absolute error is
                # above atol
                relerr_sig = relerr[abserr > atol]
                if relerr_sig.size == 0 or np.all(relerr_sig <= rtol) or \
                   (relerr_sig.max() <= 2*rtol and
                    (relerr_sig > rtol).sum() <= 0.01*relerr.size) or \
                   (iw == 0 and np.all(relerr.mean(axis=1) <= rtol)):
                    # meet tolerances
                    if log >= DETAILED_LOGGING:
                        print("Accuracy satisfied.")
                    break

            # we expect pure curvefit will succeed for the first window
            # TODO: find the energy boundary below which no poles are allowed
            if iw == 0:
                raise RuntimeError('Pure curvefit failed for the first window!')

            # Every pole of this piece is already included; adding more is
            # impossible (and would wrap lp to negative indices and loop
            # forever), so keep the fit with all poles.
            if lp <= 0 and rp >= n_poles:
                warnings.warn('Window {}: tolerance not met even with all {} '
                              'poles of piece {}.'.format(iw, n_poles, i_piece))
                break

            # try to include one more pole (next center nearest, measured by
            # the real part of the pole in sqrt(E) space)
            if rp >= n_poles:
                lp -= 1
            elif lp <= 0 or \
                    poles[rp].real - incenter <= incenter - poles[lp - 1].real:
                rp += 1
            else:
                lp -= 1

        # save data for this window
        win_data.append((i_piece, lp, rp, coefs))

        # mark the windowed poles as used poles
        poles_unused[i_piece][lp:rp] = 0

    # flatten and shrink by removing unused poles
    data = []  # used poles and residues
    for ip in range(n_pieces):
        used = (poles_unused[ip] == 0)
        # stack poles and residues for library format
        data.append(np.vstack([mp_poles[ip][used], mp_residues[ip][:, used]]).T)
    # stack poles/residues in sequence vertically
    data = np.vstack(data)
    # new start/end pole indices
    windows = []
    curvefit = []
    for iw in range(n_win):
        ip, lp, rp, coefs = win_data[iw]
        # adjust indices and change to 1-based for the library format: shift
        # by the used poles of all previous pieces and drop unused poles
        # before lp within this piece
        n_prev_poles = sum([poles_unused[i].size for i in range(ip)])
        n_unused = sum([(poles_unused[i] == 1).sum() for i in range(ip)]) + \
                  (poles_unused[ip][:lp] == 1).sum()
        lp += n_prev_poles - n_unused + 1
        rp += n_prev_poles - n_unused
        windows.append([lp, rp])
        curvefit.append(coefs)

    # construct the WindowedMultipole object
    wmp = WindowedMultipole(name)
    wmp.spacing = spacing
    wmp.sqrtAWR = sqrt(awr)
    wmp.E_min = E_min
    wmp.E_max = E_max
    wmp.data = data
    wmp.windows = np.asarray(windows)
    wmp.curvefit = np.asarray(curvefit)
    # TODO: check if Doppler broadening of the polynomial curvefit is negligible
    wmp.broaden_poly = np.ones((n_win,), dtype=bool)

    return wmp
743
744
745class WindowedMultipole(EqualityMixin):
746    """Resonant cross sections represented in the windowed multipole format.
747
748    Parameters
749    ----------
750    name : str
751        Name of the nuclide using the GND naming convention
752
753    Attributes
754    ----------
755    name : str
756        Name of the nuclide using the GND naming convention
757    spacing : float
        The width of each window in sqrt(E)-space.  For example, the first window
759        will end at (sqrt(E_min) + spacing)**2 and the second window at
760        (sqrt(E_min) + 2*spacing)**2.
761    sqrtAWR : float
762        Square root of the atomic weight ratio of the target nuclide.
763    E_min : float
764        Lowest energy in eV the library is valid for.
765    E_max : float
766        Highest energy in eV the library is valid for.
767    data : np.ndarray
        A 2D array of complex poles and residues.  data[i, 0] gives pole i
        (a complex value in sqrt(E) space).  data[i, 1:] gives the residues
        associated with the i-th pole: one each for the scattering and
        absorption channels, plus one for the fission channel if the nuclide
        is fissionable (i.e. data has 3 or 4 columns).
772    windows : np.ndarray
773        A 2D array of Integral values.  windows[i, 0] - 1 is the index of the
774        first pole in window i. windows[i, 1] - 1 is the index of the last pole
775        in window i.
776    broaden_poly : np.ndarray
777        A 1D array of boolean values indicating whether or not the polynomial
778        curvefit in that window should be Doppler broadened.
779    curvefit : np.ndarray
        A 3D array of Real curvefit polynomial coefficients.  curvefit[i, :, 0]
        gives coefficients for the scattering cross section in window i.
        curvefit[i, :, 1] gives absorption coefficients and curvefit[i, :, 2]
        gives fission coefficients.  The polynomial terms are increasing powers
        of sqrt(E) starting with 1/E e.g:
        a/E + b/sqrt(E) + c + d sqrt(E) + ...
786
787    """
788    def __init__(self, name):
789        self.name = name
790        self.spacing = None
791        self.sqrtAWR = None
792        self.E_min = None
793        self.E_max = None
794        self.data = None
795        self.windows = None
796        self.broaden_poly = None
797        self.curvefit = None
798
799    @property
800    def name(self):
801        return self._name
802
803    @property
804    def fit_order(self):
805        return self.curvefit.shape[1] - 1
806
807    @property
808    def fissionable(self):
809        return self.data.shape[1] == 4
810
811    @property
812    def n_poles(self):
813        return self.data.shape[0]
814
815    @property
816    def n_windows(self):
817        return self.windows.shape[0]
818
819    @property
820    def poles_per_window(self):
821        return (self.windows[:, 1] - self.windows[:, 0] + 1).mean()
822
823    @property
824    def spacing(self):
825        return self._spacing
826
827    @property
828    def sqrtAWR(self):
829        return self._sqrtAWR
830
831    @property
832    def E_min(self):
833        return self._E_min
834
835    @property
836    def E_max(self):
837        return self._E_max
838
839    @property
840    def data(self):
841        return self._data
842
843    @property
844    def windows(self):
845        return self._windows
846
847    @property
848    def broaden_poly(self):
849        return self._broaden_poly
850
851    @property
852    def curvefit(self):
853        return self._curvefit
854
855    @name.setter
856    def name(self, name):
857        cv.check_type('name', name, str)
858        self._name = name
859
860    @spacing.setter
861    def spacing(self, spacing):
862        if spacing is not None:
863            cv.check_type('spacing', spacing, Real)
864            cv.check_greater_than('spacing', spacing, 0.0, equality=False)
865        self._spacing = spacing
866
867    @sqrtAWR.setter
868    def sqrtAWR(self, sqrtAWR):
869        if sqrtAWR is not None:
870            cv.check_type('sqrtAWR', sqrtAWR, Real)
871            cv.check_greater_than('sqrtAWR', sqrtAWR, 0.0, equality=False)
872        self._sqrtAWR = sqrtAWR
873
874    @E_min.setter
875    def E_min(self, E_min):
876        if E_min is not None:
877            cv.check_type('E_min', E_min, Real)
878            cv.check_greater_than('E_min', E_min, 0.0, equality=True)
879        self._E_min = E_min
880
881    @E_max.setter
882    def E_max(self, E_max):
883        if E_max is not None:
884            cv.check_type('E_max', E_max, Real)
885            cv.check_greater_than('E_max', E_max, 0.0, equality=False)
886        self._E_max = E_max
887
888    @data.setter
889    def data(self, data):
890        if data is not None:
891            cv.check_type('data', data, np.ndarray)
892            if len(data.shape) != 2:
893                raise ValueError('Multipole data arrays must be 2D')
894            if data.shape[1] not in (3, 4):
895                raise ValueError(
896                     'data.shape[1] must be 3 or 4. One value for the pole.'
897                     ' One each for the scattering and absorption residues. '
898                     'Possibly one more for a fission residue.')
899            if not np.issubdtype(data.dtype, np.complexfloating):
900                raise TypeError('Multipole data arrays must be complex dtype')
901        self._data = data
902
903    @windows.setter
904    def windows(self, windows):
905        if windows is not None:
906            cv.check_type('windows', windows, np.ndarray)
907            if len(windows.shape) != 2:
908                raise ValueError('Multipole windows arrays must be 2D')
909            if not np.issubdtype(windows.dtype, np.integer):
910                raise TypeError('Multipole windows arrays must be integer'
911                                ' dtype')
912        self._windows = windows
913
914    @broaden_poly.setter
915    def broaden_poly(self, broaden_poly):
916        if broaden_poly is not None:
917            cv.check_type('broaden_poly', broaden_poly, np.ndarray)
918            if len(broaden_poly.shape) != 1:
919                raise ValueError('Multipole broaden_poly arrays must be 1D')
920            if not np.issubdtype(broaden_poly.dtype, np.bool_):
921                raise TypeError('Multipole broaden_poly arrays must be boolean'
922                                ' dtype')
923        self._broaden_poly = broaden_poly
924
925    @curvefit.setter
926    def curvefit(self, curvefit):
927        if curvefit is not None:
928            cv.check_type('curvefit', curvefit, np.ndarray)
929            if len(curvefit.shape) != 3:
930                raise ValueError('Multipole curvefit arrays must be 3D')
931            if curvefit.shape[2] not in (2, 3):  # sig_s, sig_a (maybe sig_f)
932                raise ValueError('The third dimension of multipole curvefit'
933                                 ' arrays must have a length of 2 or 3')
934            if not np.issubdtype(curvefit.dtype, np.floating):
935                raise TypeError('Multipole curvefit arrays must be float dtype')
936        self._curvefit = curvefit
937
938    @classmethod
939    def from_hdf5(cls, group_or_filename):
940        """Construct a WindowedMultipole object from an HDF5 group or file.
941
942        Parameters
943        ----------
944        group_or_filename : h5py.Group or str
945            HDF5 group containing multipole data. If given as a string, it is
946            assumed to be the filename for the HDF5 file, and the first group is
947            used to read from.
948
949        Returns
950        -------
951        openmc.data.WindowedMultipole
952            Resonant cross sections represented in the windowed multipole
953            format.
954
955        """
956
957        if isinstance(group_or_filename, h5py.Group):
958            group = group_or_filename
959            need_to_close = False
960        else:
961            h5file = h5py.File(str(group_or_filename), 'r')
962            need_to_close = True
963
964            # Make sure version matches
965            if 'version' in h5file.attrs:
966                major, minor = h5file.attrs['version']
967                if major != WMP_VERSION_MAJOR:
968                    raise DataError(
969                        'WMP data format uses version {}. {} whereas your '
970                        'installation of the OpenMC Python API expects version '
971                        '{}.x.'.format(major, minor, WMP_VERSION_MAJOR))
972            else:
973                raise DataError(
974                    'WMP data does not indicate a version. Your installation of '
975                    'the OpenMC Python API expects version {}.x data.'
976                    .format(WMP_VERSION_MAJOR))
977
978            group = list(h5file.values())[0]
979
980        name = group.name[1:]
981        out = cls(name)
982
983        # Read scalars.
984
985        out.spacing = group['spacing'][()]
986        out.sqrtAWR = group['sqrtAWR'][()]
987        out.E_min = group['E_min'][()]
988        out.E_max = group['E_max'][()]
989
990        # Read arrays.
991
992        err = "WMP '{}' array shape is not consistent with the '{}' array shape"
993
994        out.data = group['data'][()]
995
996        out.windows = group['windows'][()]
997
998        out.broaden_poly = group['broaden_poly'][...].astype(bool)
999        if out.broaden_poly.shape[0] != out.windows.shape[0]:
1000            raise ValueError(err.format('broaden_poly', 'windows'))
1001
1002        out.curvefit = group['curvefit'][()]
1003        if out.curvefit.shape[0] != out.windows.shape[0]:
1004            raise ValueError(err.format('curvefit', 'windows'))
1005
1006        # _broaden_wmp_polynomials assumes the curve fit has at least 3 terms.
1007        if out.fit_order < 2:
1008            raise ValueError("Windowed multipole is only supported for "
1009                             "curvefits with 3 or more terms.")
1010
1011        # If HDF5 file was opened here, make sure it gets closed
1012        if need_to_close:
1013            h5file.close()
1014
1015        return out
1016
1017    @classmethod
1018    def from_endf(cls, endf_file, log=False, vf_options=None, wmp_options=None):
1019        """Generate windowed multipole neutron data from an ENDF evaluation.
1020
1021        .. versionadded:: 0.12.1
1022
1023        Parameters
1024        ----------
1025        endf_file : str
1026            Path to ENDF evaluation
1027        log : bool or int, optional
1028            Whether to print running logs (use int for verbosity control)
1029        vf_options : dict, optional
1030            Dictionary of keyword arguments, e.g. {'njoy_error': 0.001},
1031            passed to :func:`openmc.data.multipole.vectfit_nuclide`
1032        wmp_options : dict, optional
1033            Dictionary of keyword arguments, e.g. {'search': True, 'rtol': 0.01},
1034            passed to :func:`openmc.data.WindowedMultipole.from_multipole`
1035
1036        Returns
1037        -------
1038        openmc.data.WindowedMultipole
1039            Resonant cross sections represented in the windowed multipole
1040            format.
1041
1042        """
1043
1044        if vf_options is None:
1045            vf_options = {}
1046
1047        if wmp_options is None:
1048            wmp_options = {}
1049
1050        if log:
1051            vf_options.update(log=log)
1052            wmp_options.update(log=log)
1053
1054        # generate multipole data from EDNF
1055        mp_data = vectfit_nuclide(endf_file, **vf_options)
1056
1057        # windowing
1058        return cls.from_multipole(mp_data, **wmp_options)
1059
1060    @classmethod
1061    def from_multipole(cls, mp_data, search=None, log=False, **kwargs):
1062        """Generate windowed multipole neutron data from multipole data.
1063
1064        Parameters
1065        ----------
1066        mp_data : dictionary or str
1067            Dictionary or Path to the multipole data stored in a pickle file
1068        search : bool, optional
1069            Whether to search for optimal window size and curvefit order.
1070            Defaults to True if no windowing parameters are specified.
1071        log : bool or int, optional
1072            Whether to print running logs (use int for verbosity control)
1073        **kwargs
1074            Keyword arguments passed to :func:`openmc.data.multipole._windowing`
1075
1076        Returns
1077        -------
1078        openmc.data.WindowedMultipole
1079            Resonant cross sections represented in the windowed multipole
1080            format.
1081
1082        """
1083
1084        if isinstance(mp_data, str):
1085            # load multipole data from file
1086            with open(mp_data, 'rb') as f:
1087                mp_data = pickle.load(f)
1088
1089        if search is None:
1090            if 'n_cf' in kwargs and ('n_win' in kwargs or 'spacing' in kwargs):
1091                search = False
1092            else:
1093                search = True
1094
1095        # windowing with specific options
1096        if not search:
1097            # set default value for curvefit order if not specified
1098            if 'n_cf' not in kwargs:
1099                kwargs.update(n_cf=5)
1100            return _windowing(mp_data, log=log, **kwargs)
1101
1102        # search optimal WMP from a range of window sizes and CF orders
1103        if log:
1104            print("Start searching ...")
1105        n_poles = sum([p.size for p in mp_data["poles"]])
1106        n_win_min = max(5, n_poles // 20)
1107        n_win_max = 2000 if n_poles < 2000 else 8000
1108        best_wmp = best_metric = None
1109        for n_w in np.unique(np.linspace(n_win_min, n_win_max, 20, dtype=int)):
1110            for n_cf in range(10, 1, -1):
1111                if log:
1112                    print("Testing N_win={} N_cf={}".format(n_w, n_cf))
1113
1114                # update arguments dictionary
1115                kwargs.update(n_win=n_w, n_cf=n_cf)
1116
1117                # windowing
1118                try:
1119                    wmp = _windowing(mp_data, log=log, **kwargs)
1120                except Exception as e:
1121                    if log:
1122                        print('Failed: ' + str(e))
1123                    break
1124
1125                # select wmp library with metric:
1126                # - performance: average # used poles per window and CF order
1127                # - memory: # windows
1128                metric = -(wmp.poles_per_window * 10. + wmp.fit_order * 1. +
1129                           wmp.n_windows * 0.01)
1130                if best_wmp is None or metric > best_metric:
1131                    if log:
1132                        print("Best library so far.")
1133                    best_wmp = deepcopy(wmp)
1134                    best_metric = metric
1135
1136        # return the best wmp library
1137        if log:
1138            print("Final library: {} poles, {} windows, {:.2g} poles per window, "
1139                  "{} CF order".format(best_wmp.n_poles, best_wmp.n_windows,
1140                   best_wmp.poles_per_window, best_wmp.fit_order))
1141
1142        return best_wmp
1143
1144    def _evaluate(self, E, T):
1145        """Compute scattering, absorption, and fission cross sections.
1146
1147        Parameters
1148        ----------
1149        E : Real
1150            Energy of the incident neutron in eV.
1151        T : Real
1152            Temperature of the target in K.
1153
1154        Returns
1155        -------
1156        3-tuple of Real
1157            Total, absorption, and fission microscopic cross sections at the
1158            given energy and temperature.
1159
1160        """
1161
1162        if E < self.E_min: return (0, 0, 0)
1163        if E > self.E_max: return (0, 0, 0)
1164
1165        # ======================================================================
1166        # Bookkeeping
1167
1168        # Define some frequently used variables.
1169        sqrtkT = sqrt(K_BOLTZMANN * T)
1170        sqrtE = sqrt(E)
1171        invE = 1.0 / E
1172
1173        # Locate us.  The i_window calc omits a + 1 present in F90 because of
1174        # the 1-based vs. 0-based indexing.  Similarly startw needs to be
1175        # decreased by 1.  endw does not need to be decreased because
1176        # range(startw, endw) does not include endw.
1177        i_window = min(self.n_windows - 1,
1178                       int(np.floor((sqrtE - sqrt(self.E_min)) / self.spacing)))
1179        startw = self.windows[i_window, 0] - 1
1180        endw = self.windows[i_window, 1]
1181
1182        # Initialize the ouptut cross sections.
1183        sig_s = 0.0
1184        sig_a = 0.0
1185        sig_f = 0.0
1186
1187        # ======================================================================
1188        # Add the contribution from the curvefit polynomial.
1189
1190        if sqrtkT != 0 and self.broaden_poly[i_window]:
1191            # Broaden the curvefit.
1192            dopp = self.sqrtAWR / sqrtkT
1193            broadened_polynomials = _broaden_wmp_polynomials(E, dopp,
1194                                                             self.fit_order + 1)
1195            for i_poly in range(self.fit_order + 1):
1196                sig_s += (self.curvefit[i_window, i_poly, _FIT_S]
1197                          * broadened_polynomials[i_poly])
1198                sig_a += (self.curvefit[i_window, i_poly, _FIT_A]
1199                          * broadened_polynomials[i_poly])
1200                if self.fissionable:
1201                    sig_f += (self.curvefit[i_window, i_poly, _FIT_F]
1202                              * broadened_polynomials[i_poly])
1203        else:
1204            temp = invE
1205            for i_poly in range(self.fit_order + 1):
1206                sig_s += self.curvefit[i_window, i_poly, _FIT_S] * temp
1207                sig_a += self.curvefit[i_window, i_poly, _FIT_A] * temp
1208                if self.fissionable:
1209                    sig_f += self.curvefit[i_window, i_poly, _FIT_F] * temp
1210                temp *= sqrtE
1211
1212        # ======================================================================
1213        # Add the contribution from the poles in this window.
1214
1215        if sqrtkT == 0.0:
1216            # If at 0K, use asymptotic form.
1217            for i_pole in range(startw, endw):
1218                psi_chi = -1j / (self.data[i_pole, _MP_EA] - sqrtE)
1219                c_temp = psi_chi / E
1220                sig_s += (self.data[i_pole, _MP_RS] * c_temp).real
1221                sig_a += (self.data[i_pole, _MP_RA] * c_temp).real
1222                if self.fissionable:
1223                    sig_f += (self.data[i_pole, _MP_RF] * c_temp).real
1224
1225        else:
1226            # At temperature, use Faddeeva function-based form.
1227            dopp = self.sqrtAWR / sqrtkT
1228            for i_pole in range(startw, endw):
1229                Z = (sqrtE - self.data[i_pole, _MP_EA]) * dopp
1230                w_val = _faddeeva(Z) * dopp * invE * sqrt(pi)
1231                sig_s += (self.data[i_pole, _MP_RS] * w_val).real
1232                sig_a += (self.data[i_pole, _MP_RA] * w_val).real
1233                if self.fissionable:
1234                    sig_f += (self.data[i_pole, _MP_RF] * w_val).real
1235
1236        return sig_s, sig_a, sig_f
1237
1238    def __call__(self, E, T):
1239        """Compute scattering, absorption, and fission cross sections.
1240
1241        Parameters
1242        ----------
1243        E : Real or Iterable of Real
1244            Energy of the incident neutron in eV.
1245        T : Real
1246            Temperature of the target in K.
1247
1248        Returns
1249        -------
1250        3-tuple of Real or 3-tuple of numpy.ndarray
1251            Total, absorption, and fission microscopic cross sections at the
1252            given energy and temperature.
1253
1254        """
1255
1256        fun = np.vectorize(lambda x: self._evaluate(x, T))
1257        return fun(E)
1258
1259    def export_to_hdf5(self, path, mode='a', libver='earliest'):
1260        """Export windowed multipole data to an HDF5 file.
1261
1262        Parameters
1263        ----------
1264        path : str
1265            Path to write HDF5 file to
1266        mode : {'r+', 'w', 'x', 'a'}
1267            Mode that is used to open the HDF5 file. This is the second argument
1268            to the :class:`h5py.File` constructor.
1269        libver : {'earliest', 'latest'}
1270            Compatibility mode for the HDF5 file. 'latest' will produce files
1271            that are less backwards compatible but have performance benefits.
1272
1273        """
1274
1275        # Open file and write version.
1276        with h5py.File(str(path), mode, libver=libver) as f:
1277            f.attrs['filetype'] = np.string_('data_wmp')
1278            f.attrs['version'] = np.array(WMP_VERSION)
1279
1280            g = f.create_group(self.name)
1281
1282            # Write scalars.
1283            g.create_dataset('spacing', data=np.array(self.spacing))
1284            g.create_dataset('sqrtAWR', data=np.array(self.sqrtAWR))
1285            g.create_dataset('E_min', data=np.array(self.E_min))
1286            g.create_dataset('E_max', data=np.array(self.E_max))
1287
1288            # Write arrays.
1289            g.create_dataset('data', data=self.data)
1290            g.create_dataset('windows', data=self.windows)
1291            g.create_dataset('broaden_poly',
1292                             data=self.broaden_poly.astype(np.int8))
1293            g.create_dataset('curvefit', data=self.curvefit)
1294