module NFCT

export NFCTplan, nfct_plan

# Shared-library suffix for the current platform.
const ending = Sys.iswindows() ? ".dll" : Sys.isapple() ? ".dylib" : ".so"

# Path to the compiled wrapper library, located in the same directory as this file.
const lib_path = string(@__DIR__, "/libnfctjulia", ending)

# NFCT flags (bit masks, passed to the C library as f1).
# Declared `const` so the functions and default arguments reading them are type-stable.
const PRE_PHI_HUT = UInt32(1) << 0
const FG_PSI = UInt32(1) << 1
const PRE_LIN_PSI = UInt32(1) << 2
const PRE_FG_PSI = UInt32(1) << 3
const PRE_PSI = UInt32(1) << 4
const PRE_FULL_PSI = UInt32(1) << 5
const MALLOC_X = UInt32(1) << 6
const MALLOC_F_HAT = UInt32(1) << 7
const MALLOC_F = UInt32(1) << 8
const FFT_OUT_OF_PLACE = UInt32(1) << 9
const FFTW_INIT = UInt32(1) << 10
const NFCT_SORT_NODES = UInt32(1) << 11
const NFCT_OMP_BLOCKWISE_ADJOINT = UInt32(1) << 12
const PRE_ONE_PSI = PRE_LIN_PSI | PRE_FG_PSI | PRE_PSI | PRE_FULL_PSI

# FFTW flags (bit masks, passed to the C library as f2).
const FFTW_MEASURE = UInt32(0)
const FFTW_DESTROY_INPUT = UInt32(1) << 0
const FFTW_EXHAUSTIVE = UInt32(1) << 3
const FFTW_PATIENT = UInt32(1) << 5
const FFTW_ESTIMATE = UInt32(1) << 6

# Default flag values. The multivariate default additionally sets
# NFCT_SORT_NODES and NFCT_OMP_BLOCKWISE_ADJOINT.
const f1_default_1d = PRE_PHI_HUT | PRE_PSI | MALLOC_X | MALLOC_F_HAT | MALLOC_F |
                      FFTW_INIT | FFT_OUT_OF_PLACE
const f1_default = f1_default_1d | NFCT_SORT_NODES | NFCT_OMP_BLOCKWISE_ADJOINT
const f2_default = FFTW_ESTIMATE | FFTW_DESTROY_INPUT

# Opaque placeholder type; only used to type pointers to the C-side plan.
mutable struct nfct_plan end

# NFCT plan object for dimension D.
#
# The numeric parameters are stored on the Julia side. The C plan is allocated
# lazily by `nfct_init`, on the first property assignment (usually `p.x = ...`).
# The fields x, f and fhat hold pointers returned by the C library. Reading them
# through dot notation wraps that memory as Julia arrays (see `Base.getproperty`).
#
# The pointer fields are typed `Ref{Float64}` rather than the concrete `Ptr{Float64}`
# on purpose: a bits-type field is always "defined", so the `isdefined` checks
# used to detect "not set yet" would stop working.
mutable struct NFCTplan{D}
    N::NTuple{D,Int32}      # bandwidth tuple
    M::Int32                # number of nodes
    n::NTuple{D,Int32}      # oversampling per dimension
    m::Int32                # window size
    f1::UInt32              # NFCT flags
    f2::UInt32              # FFTW flags
    init_done::Bool         # true once the C plan has been allocated and initialized
    finalized::Bool         # true once the C plan has been released
    x::Ref{Float64}         # nodes (C pointer)
    f::Ref{Float64}         # function values (C pointer)
    fhat::Ref{Float64}      # Fourier coefficients (C pointer)
    plan::Ref{nfct_plan}    # plan (C pointer)

    function NFCTplan{D}(N::NTuple{D,Int32}, M::Int32, n::NTuple{D,Int32}, m::Int32,
                         f1::UInt32, f2::UInt32) where {D}
        # Pointer fields stay undefined until set through the C library.
        return new(N, M, n, m, f1, f2, false, false)
    end
end

# Convenience constructor: NFCTplan((N1, N2), M) instead of NFCTplan{2}(...).
#
# Picks default parameters:
#   n[d] = 2 * nextpow(2, N[d])  (oversampling)
#   m    = 8                     (window size)
#   f1   = f1_default (D > 1) or f1_default_1d (D == 1)
#   f2   = f2_default
# Throws ErrorException if any N[d] <= 0 or M <= 0.
function NFCTplan(N::NTuple{D,Integer}, M::Integer) where {D}
    if any(x -> x <= 0, N)
        error("Every entry of N has to be an even, positive integer.")
    end

    if M <= 0
        error("M has to be a positive integer.")
    end

    # Twice the next power of two >= N[d]. Exact integer arithmetic; the former
    # ceil(log(N)/log(2)) could round up past an exact power of two.
    n = NTuple{D,Int32}(2 .* nextpow.(2, N))

    f1 = D > 1 ? f1_default : f1_default_1d

    return NFCTplan{D}(NTuple{D,Int32}(N), Int32(M), n, Int32(8), f1, f2_default)
end

# Guru constructor: every parameter can be chosen explicitly.
#
# Requirements:
#   - N entries are positive and even
#   - M > 0
#   - n entries are positive and even, with n[d] > N[d] for every d
#   - m > 0
# Violations throw ErrorException. The MALLOC_* and FFTW_INIT flags are always
# added to f1.
function NFCTplan(N::NTuple{D,Integer}, M::Integer, n::NTuple{D,Integer},
                  m::Integer=Int32(8), f1::UInt32=(D > 1 ? f1_default : f1_default_1d),
                  f2::UInt32=f2_default) where {D}
    @info "You are using the guru interface. Please consult the README if you are having trouble."

    # safety checks
    if any(x -> x <= 0, N) || any(isodd, N)
        error("Every entry of N has to be an even, positive integer.")
    end

    if M <= 0
        error("M has to be a positive integer.")
    end

    if any(x -> x <= 0, n)
        error("Every entry of n has to be an even integer.")
    end

    # Element-wise comparison; `n <= N` on tuples is lexicographic and would miss
    # e.g. n = (10, 4) against N = (8, 16).
    if any(n .<= N)
        error("Every entry of n has to be larger than the corresponding entry in N.")
    end

    if any(isodd, n)
        error("Every entry of n has to be an even integer.")
    end

    if m <= 0
        error("m has to be a positive integer.")
    end

    return NFCTplan{D}(NTuple{D,Int32}(N), Int32(M), NTuple{D,Int32}(n), Int32(m),
                       f1 | MALLOC_X | MALLOC_F_HAT | MALLOC_F | FFTW_INIT, f2)
end

# Finalizer: releases the C plan exactly once (guarded by `finalized`).
# Registered by `nfct_init` only after the plan is initialized.
function finalize_plan(P::NFCTplan{D}) where {D}
    if !P.init_done
        error("NFCTplan not initialized.")
    end

    if !P.finalized
        Core.setfield!(P, :finalized, true)
        ccall(("jnfct_finalize", lib_path), Nothing, (Ref{nfct_plan},), P.plan)
    end
end

# Allocates the C plan and initializes it with D, N, M, n, m, f1, f2.
# Then marks the plan initialized and registers `finalize_plan` as its finalizer.
function nfct_init(p::NFCTplan{D}) where {D}
    # Tuples are copied into vectors so they can be passed to C as Int32 arrays.
    Nv = collect(p.N)
    nv = collect(p.n)

    ptr = ccall(("jnfct_alloc", lib_path), Ptr{nfct_plan}, ())
    Core.setfield!(p, :plan, ptr)

    ccall(("jnfct_init", lib_path), Nothing,
          (Ref{nfct_plan}, Int32, Ref{Int32}, Int32, Ref{Int32}, Int32, UInt32, UInt32),
          ptr, D, Nv, p.M, nv, p.m, p.f1, p.f2)
    Core.setfield!(p, :init_done, true)
    finalizer(finalize_plan, p)
end

# Warning emitted for each field that is fixed once the plan exists.
const _READONLY_WARNINGS = Dict{Symbol,String}(
    :plan => "You can't modify the C pointer to the NFCT plan.",
    :num_threads => "You can't currently modify the number of threads.",
    :init_done => "You can't modify this flag.",
    :N => "You can't modify the bandwidth, please create an additional plan.",
    :M => "You can't modify the number of nodes, please create an additional plan.",
    :n => "You can't modify the oversampling parameter, please create an additional plan.",
    :m => "You can't modify the window size, please create an additional plan.",
    :f1 => "You can't modify the NFCT flags, please create an additional plan.",
    :f2 => "You can't modify the FFTW flags, please create an additional plan.",
)

# Dot-notation assignment for plans, routing x, f and fhat through C memory.
#
# The first assignment of any property initializes the C plan.
#   - x, f, fhat: validated (exact Float64 array type and size), handed to the
#     C setter, and the returned C pointer is stored in the field.
#   - Read-only parameters: emit a warning and leave the plan unchanged.
#   - Anything else: plain field assignment.
# Throws ErrorException on a finalized plan or on invalid input.
function Base.setproperty!(p::NFCTplan{D}, v::Symbol, val) where {D}
    # init plan if not done [usually with setting nodes]
    if !p.init_done
        nfct_init(p)
    end

    if p.finalized
        error("NFCTplan already finalized")
    end

    if v == :x
        # nodes: length-M vector for D == 1, D x M matrix otherwise
        if D == 1
            if typeof(val) != Vector{Float64}
                error("x has to be a Float64 vector.")
            end
            if size(val)[1] != p.M
                error("x has to be a Float64 vector of length M.")
            end
        else
            if typeof(val) != Array{Float64,2}
                error("x has to be a Float64 matrix.")
            end
            if size(val)[1] != D || size(val)[2] != p.M
                error("x has to be a Float64 matrix of size dxM.")
            end
        end
        ptr = ccall(("jnfct_set_x", lib_path), Ptr{Float64},
                    (Ref{nfct_plan}, Ref{Cdouble}), p.plan, val)
        return Core.setfield!(p, v, ptr)
    elseif v == :f
        if typeof(val) != Array{Float64,1}
            error("f has to be a Float64 vector.")
        end
        if size(val)[1] != p.M
            error("f has to be a Float64 vector of size M.")
        end
        ptr = ccall(("jnfct_set_f", lib_path), Ptr{Float64},
                    (Ref{nfct_plan}, Ref{Float64}), p.plan, val)
        return Core.setfield!(p, v, ptr)
    elseif v == :fhat
        if typeof(val) != Array{Float64,1}
            error("fhat has to be a Float64 vector.")
        end
        # Int product: an Int32 product of N could overflow silently.
        if size(val)[1] != prod(Int, p.N)
            error("fhat has to be a Float64 vector of size prod(N).")
        end
        ptr = ccall(("jnfct_set_fhat", lib_path), Ptr{Float64},
                    (Ref{nfct_plan}, Ref{Float64}), p.plan, val)
        return Core.setfield!(p, v, ptr)
    elseif haskey(_READONLY_WARNINGS, v)
        @warn _READONLY_WARNINGS[v]
        return nothing
    else
        # handle other set operations the default way
        return Core.setfield!(p, v, val)
    end
end

# Returns the C pointer stored in buffer field `name`.
# Throws if the field was never set or if the plan has been finalized; in the
# latter case the pointer would reference released memory.
function _c_buffer(p::NFCTplan, name::Symbol)
    isdefined(p, name) || error("$(name) is not set.")
    Core.getfield(p, :finalized) && error("NFCTplan already finalized")
    return Core.getfield(p, name)
end

# Dot-notation access for plans.
#   - x, f, fhat: return Julia arrays wrapping C memory (no copy).
#   - num_threads: queried from the C library.
#   - Any other field: read directly.
function Base.getproperty(p::NFCTplan{D}, v::Symbol) where {D}
    if v == :x
        ptr = _c_buffer(p, :x)
        M = Int(Core.getfield(p, :M))
        if D == 1
            return unsafe_wrap(Vector{Float64}, ptr, M)
        else
            return unsafe_wrap(Matrix{Float64}, ptr, (D, M))
        end
    elseif v == :num_threads
        return ccall(("nfft_get_num_threads", lib_path), Int64, ())
    elseif v == :f
        ptr = _c_buffer(p, :f)
        return unsafe_wrap(Vector{Float64}, ptr, Int(Core.getfield(p, :M)))
    elseif v == :fhat
        ptr = _c_buffer(p, :fhat)
        return unsafe_wrap(Vector{Float64}, ptr, prod(Int, Core.getfield(p, :N)))
    else
        return Core.getfield(p, v)
    end
end

# Shared precondition check for the transforms.
# Throws if the plan is finalized, if the transform input `input` has not been
# set, or if the nodes x have not been set.
function _assert_ready(P::NFCTplan, input::Symbol)
    if P.finalized
        error("NFCTplan already finalized")
    end
    if !isdefined(P, input)
        error("$(input) has not been set.")
    end
    if !isdefined(P, :x)
        error("x has not been set.")
    end
    return nothing
end

# Direct (non-fast) NFCT: computes f from fhat and x [NFCT.trafo_direct].
function trafo_direct(P::NFCTplan{D}) where {D}
    _assert_ready(P, :fhat)
    ptr = ccall(("jnfct_trafo_direct", lib_path), Ptr{Float64}, (Ref{nfct_plan},), P.plan)
    return Core.setfield!(P, :f, ptr)
end

# Direct (non-fast) adjoint NFCT: computes fhat from f and x [NFCT.adjoint_direct].
function adjoint_direct(P::NFCTplan{D}) where {D}
    _assert_ready(P, :f)
    ptr = ccall(("jnfct_adjoint_direct", lib_path), Ptr{Float64}, (Ref{nfct_plan},), P.plan)
    return Core.setfield!(P, :fhat, ptr)
end

# Fast NFCT: computes f from fhat and x [NFCT.trafo].
function trafo(P::NFCTplan{D}) where {D}
    _assert_ready(P, :fhat)
    ptr = ccall(("jnfct_trafo", lib_path), Ptr{Float64}, (Ref{nfct_plan},), P.plan)
    return Core.setfield!(P, :f, ptr)
end

# Fast adjoint NFCT: computes fhat from f and x [NFCT.adjoint].
# Module-local function; it shadows, and does not extend, Base.adjoint.
function adjoint(P::NFCTplan{D}) where {D}
    _assert_ready(P, :f)
    ptr = ccall(("jnfct_adjoint", lib_path), Ptr{Float64}, (Ref{nfct_plan},), P.plan)
    return Core.setfield!(P, :fhat, ptr)
end

end # module NFCT