// Build-time configuration consumed by ltfat_oct_template_helper.h, which is
// included right after these definitions.
// NOTE(review): the exact meaning of each switch is defined in that header and
// is not visible in this file -- confirm there before changing any of them.

// Indices 0, 1 and 2 are the arguments that octFunction() converts with
// ltfatOctArray<LTFAT_TYPE> (s, tgrad, fgrad); presumably the helper uses this
// list to choose the numeric type of the instantiation -- TODO confirm.
#define TYPEDEPARGS 0, 1, 2
// Presumably enables the single-precision instantiation (this file provides a
// float overload of fwd_heapint) -- TODO confirm.
#define SINGLEARGS
// Presumably declares that the type-dependent arguments are real-valued
// -- TODO confirm.
#define REALARGS
// Presumably the name under which the oct-file function is registered.
#define OCTFILENAME comp_heapint
// Help text. The continuation lines are part of the string literal, so their
// leading spaces are emitted verbatim; do not re-indent them.
#define OCTFILEHELP "Computes heapint.\n\
                    Usage: c = comp_heapint(s, itime, ifreq, a, tol, do_timeinv);\n\
                    Yeah."
8 
9 
10 #include "ltfat_oct_template_helper.h"
11 
12 static inline void
fwd_heapint(const double * s,const double * tgrad,const double * fgrad,const octave_idx_type a,const octave_idx_type M,const octave_idx_type L,const octave_idx_type W,double tol,int phasetype,double * phase)13 fwd_heapint(const double *s, const double *tgrad, const double *fgrad,
14             const octave_idx_type a, const octave_idx_type M,
15             const octave_idx_type L, const octave_idx_type W,
16             double tol, int phasetype, double *phase)
17 {
18     if (phasetype == 2)
19         ltfat_heapint_d(s, tgrad, fgrad, a, M, L, W, tol, phase);
20     else
21         ltfat_heapint_relgrad_d(s, tgrad, fgrad, a, M, L, W, tol,
22                           static_cast<ltfat_phaseconvention>(phasetype), phase);
23 
24 }
25 
26 static inline void
fwd_heapint(const float * s,const float * tgrad,const float * fgrad,const octave_idx_type a,const octave_idx_type M,const octave_idx_type L,const octave_idx_type W,float tol,int phasetype,float * phase)27 fwd_heapint(const float *s, const float *tgrad, const float *fgrad,
28             const octave_idx_type a, const octave_idx_type M,
29             const octave_idx_type L, const octave_idx_type W,
30             float tol, int phasetype, float *phase)
31 {
32     if (phasetype == 2)
33         ltfat_heapint_s(s, tgrad, fgrad, a, M, L, W, tol, phase);
34     else
35         ltfat_heapint_relgrad_s(s, tgrad, fgrad, a, M, L, W, tol,
36                           static_cast<ltfat_phaseconvention>(phasetype), phase);
37 }
38 
39 template <class LTFAT_TYPE, class LTFAT_REAL, class LTFAT_COMPLEX>
octFunction(const octave_value_list & args,int nargout)40 octave_value_list octFunction(const octave_value_list& args, int nargout)
41 {
42     MArray<LTFAT_TYPE> s = ltfatOctArray<LTFAT_TYPE>(args(0));
43     MArray<LTFAT_TYPE> tgrad = ltfatOctArray<LTFAT_TYPE>(args(1));
44     MArray<LTFAT_TYPE> fgrad = ltfatOctArray<LTFAT_TYPE>(args(2));
45     const octave_idx_type a  = args(3).int_value();
46     const double tol   = args(4).double_value();
47     const int phasetype = args(5).int_value() == 1? LTFAT_TIMEINV: LTFAT_FREQINV;
48 
49     const octave_idx_type M = args(0).rows();
50     const octave_idx_type N = args(0).columns();
51     const octave_idx_type L = N * a;
52     const octave_idx_type W = s.numel() / (M * N);
53 
54     MArray<LTFAT_TYPE> phase(dim_vector(M, N, W));
55 
56     fwd_heapint(s.data(), tgrad.data(), fgrad.data(), a, M, L, W, tol,
57                 phasetype, phase.fortran_vec());
58 
59     return octave_value(phase);
60 }
61