1module NFFT
2
3export Plan,nfft_plan
4
# Shared-library file extension for the current operating system.
# `const` so the value is concretely typed and cannot be reassigned after load.
const ending = Sys.iswindows() ? ".dll" : Sys.isapple() ? ".dylib" : ".so"

# Path to the compiled C wrapper library, expected in the same directory as this file.
const lib_path = joinpath(@__DIR__, "libnfftjulia" * ending)
16
# NFFT plan flags (bit masks passed to the C library as `f1`).
# Declared `const` so every use (default arguments, flag arithmetic in the
# constructors) sees a concretely typed `UInt32` instead of an untyped global.
# NOTE(review): bit positions are assumed to mirror the C library's nfft3.h — keep in sync.
const PRE_PHI_HUT = UInt32(1) << 0
const FG_PSI = UInt32(1) << 1
const PRE_LIN_PSI = UInt32(1) << 2
const PRE_FG_PSI = UInt32(1) << 3
const PRE_PSI = UInt32(1) << 4
const PRE_FULL_PSI = UInt32(1) << 5
const MALLOC_X = UInt32(1) << 6
const MALLOC_F_HAT = UInt32(1) << 7
const MALLOC_F = UInt32(1) << 8
const FFT_OUT_OF_PLACE = UInt32(1) << 9
const FFTW_INIT = UInt32(1) << 10
const NFFT_SORT_NODES = UInt32(1) << 11
const NFFT_OMP_BLOCKWISE_ADJOINT = UInt32(1) << 12
# union of the four psi-precomputation strategies
const PRE_ONE_PSI = PRE_LIN_PSI | PRE_FG_PSI | PRE_PSI | PRE_FULL_PSI
32
# FFTW planner flags (bit masks passed to the C library as `f2`).
# Declared `const` so they are concretely typed `UInt32` wherever they are used.
const FFTW_MEASURE = UInt32(0)
const FFTW_DESTROY_INPUT = UInt32(1) << 0
const FFTW_UNALIGNED = UInt32(1) << 1
const FFTW_CONSERVE_MEMORY = UInt32(1) << 2
const FFTW_EXHAUSTIVE = UInt32(1) << 3
const FFTW_PRESERVE_INPUT = UInt32(1) << 4
const FFTW_PATIENT = UInt32(1) << 5
const FFTW_ESTIMATE = UInt32(1) << 6
const FFTW_WISDOM_ONLY = UInt32(1) << 21
43
# Default flag values, `const` so the constructors' default arguments are type-stable.
# 1-D default: no node sorting and no blockwise OpenMP adjoint.
const f1_default_1d = UInt32(PRE_PHI_HUT | PRE_PSI | MALLOC_X | MALLOC_F_HAT | MALLOC_F | FFTW_INIT | FFT_OUT_OF_PLACE)
# Multi-dimensional default: the 1-D set plus NFFT_SORT_NODES and NFFT_OMP_BLOCKWISE_ADJOINT.
const f1_default = UInt32(PRE_PHI_HUT | PRE_PSI | MALLOC_X | MALLOC_F_HAT | MALLOC_F | FFTW_INIT | FFT_OUT_OF_PLACE | NFFT_SORT_NODES | NFFT_OMP_BLOCKWISE_ADJOINT)
# Default FFTW flags.
const f2_default = UInt32(FFTW_ESTIMATE | FFTW_DESTROY_INPUT)
48
# Default window cut-off parameter `m`, queried once from the C library when the module loads.
# `const` because it is read in the guru `Plan` constructor's default argument.
const default_window_cut_off = ccall(("nfft_get_default_window_cut_off", lib_path), Int64, ())
51
# Opaque placeholder for the C-side plan; in this module it only appears as the
# pointee type of `Ref`/`Ptr` in ccall signatures and the `Plan.plan` field.
mutable struct nfft_plan end
55
# Julia-side handle for one NFFT plan of dimension `D`.
# The scalar parameters are fixed at construction; the C plan and its data
# pointers are filled in later (the constructor leaves them undefined).
mutable struct Plan{D}
	N::NTuple{D,Int32}      # bandwidth per dimension
	M::Int32                # number of nodes
	n::NTuple{D,Int32}      # oversampled FFT size per dimension
	m::Int32                # window cut-off parameter
	f1::UInt32              # NFFT flags
	f2::UInt32              # FFTW flags
	init_done::Bool         # set once the C plan has been allocated and initialised
	finalized::Bool         # set once the C plan has been finalized
	x::Ref{Float64}         # C pointer to the nodes
	f::Ref{ComplexF64}      # C pointer to the function values
	fhat::Ref{ComplexF64}   # C pointer to the Fourier coefficients
	plan::Ref{nfft_plan}    # C pointer to the plan
	# Only the eight scalar fields are set; the pointer fields stay #undef so that
	# `isdefined` can report them as unset (this needs their non-isbits field types).
	Plan{D}(N::NTuple{D,Int32}, M::Int32, n::NTuple{D,Int32}, m::Int32, f1::UInt32, f2::UInt32) where {D} = new(N, M, n, m, f1, f2, false, false)
end
75
# Convenience constructor, so `Plan((N, N), M)` can be written instead of
# `Plan{2}((N, N), M)`.
#
# - `N`: bandwidth per dimension; every entry must be an even, positive integer.
# - `M`: number of nodes; must be positive.
# Chooses the oversampled size n = 2 * nextpow(2, N) per dimension, window
# cut-off 8 and the default NFFT/FFTW flags for dimension D.
# Throws an ErrorException on invalid input.
function Plan(N::NTuple{D,Integer},M::Integer) where {D}
	if any(x -> x <= 0, N) || any(isodd, N)
		error("Every entry of N has to be an even, positive integer." )
	end

	if M <= 0
		error("M has to be a positive integer." )
	end

	# Default oversampling in exact integer arithmetic. 2 * nextpow(2, N) equals
	# 2^(ceil(log2(N)) + 1), but cannot be thrown off by a floating-point
	# log(N)/log(2) ratio rounding just above an integer at exact powers of two.
	n = NTuple{D,Int32}(2 .* nextpow.(2, N))

	# default NFFT flags depend on the dimension
	f1 = D > 1 ? f1_default : f1_default_1d

	# NOTE(review): the window cut-off is fixed at 8 here, whereas the guru
	# constructor defaults to `default_window_cut_off` — confirm this is intended.
	Plan{D}(NTuple{D,Int32}(N),Int32(M),n,Int32(8),f1,f2_default)
end
108
# Guru constructor: the caller chooses the oversampled sizes `n`, the window
# cut-off `m` and the NFFT (`f1`) / FFTW (`f2`) flags.
# MALLOC_X, MALLOC_F_HAT, MALLOC_F and FFTW_INIT are always OR-ed into `f1`.
# Throws an ErrorException on invalid input.
function Plan(
	N::NTuple{D,Integer},
	M::Integer,
	n::NTuple{D,Integer},
	m::Integer=Int32(default_window_cut_off),
	f1::UInt32=(D > 1 ? f1_default : f1_default_1d),
	f2::UInt32=f2_default,
) where {D}
	@info "You are using the guru interface. Please consult the README if you are having trouble."

	# safety checks
	if any(x -> x <= 0, N) || any(isodd, N)
		error("Every entry of N has to be an even, positive integer." )
	end

	if M <= 0
		error("M has to be a positive integer." )
	end

	if any(x -> x <= 0, n)
		error("Every entry of n has to be an even integer." )
	end

	# Entry-wise comparison: `n <= N` on tuples is lexicographic and would
	# accept e.g. n = (16, 4) with N = (8, 8).
	if any(n .<= N)
		error("Every entry of n has to be larger than the corresponding entry in N." )
	end

	if any(isodd, n)
		error("Every entry of n has to be an even integer." )
	end

	if m <= 0
		error("m has to be a positive integer." )
	end

	Plan{D}(NTuple{D,Int32}(N),Int32(M),NTuple{D,Int32}(n),Int32(m),(f1 | MALLOC_X | MALLOC_F_HAT | MALLOC_F | FFTW_INIT),f2)
end
143
# Finalizer for `Plan` (registered by `nfft_init`); may also be called by hand.
# Calls `jnfft_finalize` at most once per plan and returns `nothing`.
# Errors if the C plan was never created.
function finalize_plan(P::Plan{D}) where {D}
	P.init_done || error("Plan not initialized.")
	P.finalized && return nothing

	# raise the flag first so any later call is a no-op
	Core.setfield!(P, :finalized, true)
	ccall(("jnfft_finalize", lib_path), Nothing, (Ref{nfft_plan},), P.plan)
	return nothing
end
155
# Allocate the C plan for `p`, initialise it with D, N, M, n, m, f1 and f2,
# mark `p` as initialised and register `finalize_plan` as its finalizer.
# Returns `p`.
function nfft_init(p::Plan{D}) where {D}
	plan_ptr = ccall(("jnfft_alloc", lib_path), Ptr{nfft_plan}, ())
	Core.setfield!(p, :plan, plan_ptr)

	# the C side expects the per-dimension sizes as Int32 arrays
	bandwidths = collect(p.N)
	fft_sizes = collect(p.n)
	ccall(("jnfft_init", lib_path), Nothing,
		(Ref{nfft_plan}, Int32, Ref{Int32}, Int32, Ref{Int32}, Int32, UInt32, UInt32),
		plan_ptr, D, bandwidths, p.M, fft_sizes, p.m, p.f1, p.f2)

	Core.setfield!(p, :init_done, true)
	return finalizer(finalize_plan, p)
end
173
# Dot-assignment for `Plan`, routed through the C plan.
#
# - `:x`, `:f`, `:fhat`: the value is type/shape-checked, handed to the matching
#   `jnfft_set_*` wrapper, and the C pointer that wrapper returns is stored.
# - plan parameters and bookkeeping fields (`plan`, `num_threads`, `init_done`,
#   `finalized`, `N`, `M`, `n`, `m`, `f1`, `f2`): read-only, a warning is logged.
#
# Any assignment first allocates and initialises the C plan via `nfft_init` if
# that has not happened yet. Errors if the plan has already been finalized.
function Base.setproperty!(p::Plan{D},v::Symbol,val) where {D}
	# lazily create the C plan [usually triggered by setting the nodes]
	if !p.init_done
		nfft_init(p)
	end

	# the C plan must still be alive
	if p.finalized
		error("Plan already finalized")
	end

	if v == :x
		# nodes: length-M vector for D == 1, D x M matrix otherwise
		if D == 1
			if !(val isa Vector{Float64})
				error("x has to be a Float64 vector.")
			end
			if size(val, 1) != p.M
				error("x has to be a Float64 vector of length M.")
			end
		else
			if !(val isa Matrix{Float64})
				error("x has to be a Float64 matrix.")
			end
			if size(val, 1) != D || size(val, 2) != p.M
				error("x has to be a Float64 matrix of size dxM.")
			end
		end
		ptr = ccall(("jnfft_set_x",lib_path),Ptr{Float64},(Ref{nfft_plan},Ref{Cdouble}),p.plan,val)
		Core.setfield!(p,v,ptr)
	elseif v == :f
		# function values: vector of length M
		if !(val isa Vector{ComplexF64})
			error("f has to be a ComplexFloat64 vector.")
		end
		if size(val, 1) != p.M
			error("f has to be a ComplexFloat64 vector of size M.")
		end
		ptr = ccall(("jnfft_set_f",lib_path),Ptr{ComplexF64},(Ref{nfft_plan},Ref{ComplexF64}),p.plan,val)
		Core.setfield!(p,v,ptr)
	elseif v == :fhat
		# Fourier coefficients: vector of length prod(N)
		if !(val isa Vector{ComplexF64})
			error("fhat has to be a ComplexFloat64 vector.")
		end
		if size(val, 1) != prod(p.N)
			error("fhat has to be a ComplexFloat64 vector of size prod(N).")
		end
		ptr = ccall(("jnfft_set_fhat",lib_path),Ptr{ComplexF64},(Ref{nfft_plan},Ref{ComplexF64}),p.plan,val)
		Core.setfield!(p,v,ptr)
	elseif v == :plan
		@warn "You can't modify the C pointer to the NFFT plan."
	elseif v == :num_threads
		@warn "You can't currently modify the number of threads."
	elseif v == :init_done || v == :finalized
		# `finalized` is protected as well. Setting it by hand would either make
		# `finalize_plan` skip `jnfft_finalize`, or lift the finalized guards on a
		# plan whose C side has already been finalized.
		@warn "You can't modify this flag."
	elseif v == :N
		@warn "You can't modify the bandwidth, please create an additional plan."
	elseif v == :M
		@warn "You can't modify the number of nodes, please create an additional plan."
	elseif v == :n
		@warn "You can't modify the oversampling parameter, please create an additional plan."
	elseif v == :m
		@warn "You can't modify the window size, please create an additional plan."
	elseif v == :f1
		@warn "You can't modify the NFFT flags, please create an additional plan."
	elseif v == :f2
		@warn "You can't modify the FFTW flags, please create an additional plan."
	else
		# any other name: default behaviour (setfield! errors on unknown fields)
		Core.setfield!(p,v,val)
	end
end
250
# Dot-access for `Plan`.
# `:x`, `:f` and `:fhat` are returned as Julia arrays that wrap the C plan's
# memory via `unsafe_wrap` (no copy; writes go straight to the C side).
# `:num_threads` is queried from the library; any other name is a plain field read.
# Errors if a data field is unset or if the plan has already been finalized.
function Base.getproperty(p::Plan{D},v::Symbol) where {D}
	if (v == :x || v == :f || v == :fhat) && Core.getfield(p, :finalized)
		# These arrays would alias memory already handed to jnfft_finalize
		# (presumably freed there, since the plan is created with the MALLOC_* flags).
		# Refuse rather than return a dangling view.
		error("Plan already finalized")
	end

	if v == :x
		if !isdefined(p,:x)
			error("x is not set.")
		end
		ptr = Core.getfield(p,:x)
		if D == 1
			return unsafe_wrap(Vector{Float64},ptr,p.M)             # nodes as a length-M vector
		else
			return unsafe_wrap(Matrix{Float64},ptr,(D,Int64(p.M)))  # nodes as a D x M matrix
		end
	elseif v == :num_threads
		return ccall(("nfft_get_num_threads", lib_path),Int64,())
	elseif v == :f
		if !isdefined(p,:f)
			error("f is not set.")
		end
		ptr = Core.getfield(p,:f)
		return unsafe_wrap(Vector{ComplexF64},ptr,p.M)  # function values, length M
	elseif v == :fhat
		if !isdefined(p,:fhat)
			error("fhat is not set.")
		end
		ptr = Core.getfield(p,:fhat)
		return unsafe_wrap(Vector{ComplexF64},ptr,prod(p.N)) # Fourier coefficients, length prod(N)
	else
		return Core.getfield(p,v)
	end
end
281
# Direct transform via `jnfft_trafo_direct` [call with NFFT.trafo_direct outside module].
# Requires `fhat` and `x` to be set; stores the returned C pointer as `f` and returns it.
function trafo_direct(P::Plan{D}) where {D}
	P.finalized && error("Plan already finalized")
	isdefined(P, :fhat) || error("fhat has not been set.")
	isdefined(P, :x) || error("x has not been set.")

	f_ptr = ccall(("jnfft_trafo_direct", lib_path), Ptr{ComplexF64}, (Ref{nfft_plan},), P.plan)
	return Core.setfield!(P, :f, f_ptr)
end
300
# Direct adjoint transform via `jnfft_adjoint_direct` [call with NFFT.adjoint_direct outside module].
# Requires `f` and `x` to be set; stores the returned C pointer as `fhat` and returns it.
function adjoint_direct(P::Plan{D}) where {D}
	P.finalized && error("Plan already finalized")
	isdefined(P, :f) || error("f has not been set.")
	isdefined(P, :x) || error("x has not been set.")

	fhat_ptr = ccall(("jnfft_adjoint_direct", lib_path), Ptr{ComplexF64}, (Ref{nfft_plan},), P.plan)
	return Core.setfield!(P, :fhat, fhat_ptr)
end
316
# Transform via `jnfft_trafo` [call with NFFT.trafo outside module].
# Requires `fhat` and `x` to be set; stores the returned C pointer as `f` and returns it.
function trafo(P::Plan{D}) where {D}
	P.finalized && error("Plan already finalized")
	isdefined(P, :fhat) || error("fhat has not been set.")
	isdefined(P, :x) || error("x has not been set.")

	f_ptr = ccall(("jnfft_trafo", lib_path), Ptr{ComplexF64}, (Ref{nfft_plan},), P.plan)
	return Core.setfield!(P, :f, f_ptr)
end
332
# Adjoint transform via `jnfft_adjoint` [call with NFFT.adjoint outside module].
# Requires `f` and `x` to be set; stores the returned C pointer as `fhat` and returns it.
function adjoint(P::Plan{D}) where {D}
	P.finalized && error("Plan already finalized")
	isdefined(P, :f) || error("f has not been set.")
	isdefined(P, :x) || error("x has not been set.")

	fhat_ptr = ccall(("jnfft_adjoint", lib_path), Ptr{ComplexF64}, (Ref{nfft_plan},), P.plan)
	return Core.setfield!(P, :fhat, fhat_ptr)
end
348
349# module end
350end
351