// Build configuration consumed by ltfat_oct_template_helper.h (not visible
// here). NOTE(review): the meanings of the first three macros are inferred
// from their names and from how octFunction uses its arguments -- confirm
// against the helper header.
//
// TYPEDEPARGS: presumably the indices of the arguments whose numeric class
//              (double/single) selects the template instantiation; here
//              s, itime and ifreq (args 0, 1, 2).
#define TYPEDEPARGS 0, 1, 2
// SINGLEARGS: presumably enables a single-precision instantiation (this file
//             provides a float overload of fwd_heapint for it).
#define SINGLEARGS
// REALARGS: presumably restricts the accepted inputs to real arrays.
#define REALARGS
// Name of the generated Octave function.
#define OCTFILENAME comp_heapint
// Help text for the generated function. The previous text ended in a
// placeholder ("Yeah.") and described no arguments.
#define OCTFILEHELP                                                        \
    "Computes a phase array by heap integration.\n"                        \
    "Usage: phase = comp_heapint(s, itime, ifreq, a, tol, do_timeinv);\n"  \
    "  s, itime, ifreq  real arrays of equal size M x N (x W)\n"           \
    "  a                time shift; the signal length is N*a\n"            \
    "  tol              tolerance passed on to the heap integration\n"     \
    "  do_timeinv       1 selects the time-invariant phase convention,\n"  \
    "                   0 the frequency-invariant one\n"
8
9
10 #include "ltfat_oct_template_helper.h"
11
12 static inline void
fwd_heapint(const double * s,const double * tgrad,const double * fgrad,const octave_idx_type a,const octave_idx_type M,const octave_idx_type L,const octave_idx_type W,double tol,int phasetype,double * phase)13 fwd_heapint(const double *s, const double *tgrad, const double *fgrad,
14 const octave_idx_type a, const octave_idx_type M,
15 const octave_idx_type L, const octave_idx_type W,
16 double tol, int phasetype, double *phase)
17 {
18 if (phasetype == 2)
19 ltfat_heapint_d(s, tgrad, fgrad, a, M, L, W, tol, phase);
20 else
21 ltfat_heapint_relgrad_d(s, tgrad, fgrad, a, M, L, W, tol,
22 static_cast<ltfat_phaseconvention>(phasetype), phase);
23
24 }
25
26 static inline void
fwd_heapint(const float * s,const float * tgrad,const float * fgrad,const octave_idx_type a,const octave_idx_type M,const octave_idx_type L,const octave_idx_type W,float tol,int phasetype,float * phase)27 fwd_heapint(const float *s, const float *tgrad, const float *fgrad,
28 const octave_idx_type a, const octave_idx_type M,
29 const octave_idx_type L, const octave_idx_type W,
30 float tol, int phasetype, float *phase)
31 {
32 if (phasetype == 2)
33 ltfat_heapint_s(s, tgrad, fgrad, a, M, L, W, tol, phase);
34 else
35 ltfat_heapint_relgrad_s(s, tgrad, fgrad, a, M, L, W, tol,
36 static_cast<ltfat_phaseconvention>(phasetype), phase);
37 }
38
39 template <class LTFAT_TYPE, class LTFAT_REAL, class LTFAT_COMPLEX>
octFunction(const octave_value_list & args,int nargout)40 octave_value_list octFunction(const octave_value_list& args, int nargout)
41 {
42 MArray<LTFAT_TYPE> s = ltfatOctArray<LTFAT_TYPE>(args(0));
43 MArray<LTFAT_TYPE> tgrad = ltfatOctArray<LTFAT_TYPE>(args(1));
44 MArray<LTFAT_TYPE> fgrad = ltfatOctArray<LTFAT_TYPE>(args(2));
45 const octave_idx_type a = args(3).int_value();
46 const double tol = args(4).double_value();
47 const int phasetype = args(5).int_value() == 1? LTFAT_TIMEINV: LTFAT_FREQINV;
48
49 const octave_idx_type M = args(0).rows();
50 const octave_idx_type N = args(0).columns();
51 const octave_idx_type L = N * a;
52 const octave_idx_type W = s.numel() / (M * N);
53
54 MArray<LTFAT_TYPE> phase(dim_vector(M, N, W));
55
56 fwd_heapint(s.data(), tgrad.data(), fgrad.data(), a, M, L, W, tol,
57 phasetype, phase.fortran_vec());
58
59 return octave_value(phase);
60 }
61