!! Copyright (C) 2005-2010 Florian Lorenzen, Heiko Appel, X. Andrade
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module boundaries_oct_m
  use accel_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use global_oct_m
  use math_oct_m
  use messages_oct_m
  use mesh_oct_m
  use mpi_oct_m
  use mpi_debug_oct_m
  use namespace_oct_m
  use par_vec_oct_m
  use parser_oct_m
  use profiling_oct_m
  use simul_box_oct_m
  use subarray_oct_m
  use types_oct_m
  use unit_oct_m
  use unit_system_oct_m

  implicit none

  private

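  !> Lookup tables needed to impose the boundary conditions of the mesh.
  !! For periodic meshes, per_points maps every boundary point to the inner
  !! point it is a periodic copy of; in domain-parallel runs, nsend, nrecv,
  !! per_send and per_recv describe the point values that have to be
  !! exchanged with the other partitions, and the buff_* members hold device
  !! copies of these tables when an accelerator is in use.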
  type boundaries_t
    private
    type(mesh_t), pointer :: mesh
    integer          :: nper             !< the number of points that correspond to pbc
    integer, pointer :: per_points(:, :) !< (1:2, 1:nper) the list of points that correspond to pbc
    integer, pointer :: per_send(:, :)   !< local points whose values are sent to each partition
    integer, pointer :: per_recv(:, :)   !< local boundary points whose values are received from each partition
    integer, pointer :: nsend(:)         !< how many points are sent to each partition
    integer, pointer :: nrecv(:)         !< how many points are received from each partition
    type(accel_mem_t) :: buff_per_points
    type(accel_mem_t) :: buff_per_send
    type(accel_mem_t) :: buff_per_recv
    type(accel_mem_t) :: buff_nsend
    type(accel_mem_t) :: buff_nrecv
    logical, public   :: spiralBC           !< set .true. when SpiralBoundaryCondition is set in the input file
    logical, public   :: spiral             !< set .true. after the first time step IF spiralBC == .true. (see td_run in td.F90)
    FLOAT,   public   :: spiral_q(MAX_DIM)
  end type boundaries_t

  public ::                        &
    boundaries_t,                  &
    boundaries_nullify,            &
    boundaries_init,               &
    boundaries_end,                &
    boundaries_set

  public ::                        &
    pv_handle_batch_t,             &
    dvec_ghost_update,             &
    zvec_ghost_update,             &
    dghost_update_batch_start,     &
    zghost_update_batch_start,     &
    dghost_update_batch_finish,    &
    zghost_update_batch_finish

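  ! Indices into the first dimension of boundaries_t%per_points:
  ! per_points(POINT_BOUNDARY, iper) is a boundary point and
  ! per_points(POINT_INNER, iper) is the inner point it is a copy of.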
  integer, parameter, public ::    &
    POINT_BOUNDARY = 1,            &
    POINT_INNER    = 2

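  !> Handle for a non-blocking update of the ghost points of a batch:
  !! the d/zghost_update_batch_start routines post the exchange and return,
  !! while the matching d/zghost_update_batch_finish routines wait for its
  !! completion. The send/receive buffers below are used as packing space
  !! for the transferred values.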
  type pv_handle_batch_t
    private
    type(batch_t)        :: ghost_send
    integer,     pointer :: requests(:)
    integer              :: nnb
    ! these are needed for OpenCL
    FLOAT, pointer       :: drecv_buffer(:)
    CMPLX, pointer       :: zrecv_buffer(:)
    FLOAT, pointer       :: dsend_buffer(:)
    CMPLX, pointer       :: zsend_buffer(:)
    type(batch_t),   pointer :: v_local
    type(pv_t),      pointer :: vp
  end type pv_handle_batch_t

  type(profile_t), save :: prof_start
  type(profile_t), save :: prof_wait
  type(profile_t), save :: prof_update
  type(profile_t), save :: set_bc_prof
  type(profile_t), save :: set_bc_comm_prof
  type(profile_t), save :: set_bc_precomm_prof
  type(profile_t), save :: set_bc_postcomm_prof

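  ! Generic interface: boundaries_set works either on a whole batch or on a
  ! single real/complex mesh function; the type-specific versions are
  ! generated from the templates included at the end of this file.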
  interface boundaries_set
    module procedure boundaries_set_batch
    module procedure dboundaries_set_single
    module procedure zboundaries_set_single
  end interface boundaries_set

contains

  ! ---------------------------------------------------------
  elemental subroutine boundaries_nullify(this)
    type(boundaries_t), intent(out) :: this

    nullify(this%mesh)
    this%nper = 0
    nullify(this%per_points, this%per_send, this%per_recv)
    nullify(this%nsend, this%nrecv)
    call accel_mem_nullify(this%buff_per_points)
    call accel_mem_nullify(this%buff_per_send)
    call accel_mem_nullify(this%buff_per_recv)
    call accel_mem_nullify(this%buff_nsend)
    call accel_mem_nullify(this%buff_nrecv)
    this%spiralBC = .false.
    this%spiral = .false.
    this%spiral_q(1:MAX_DIM) = M_ZERO

  end subroutine boundaries_nullify

  ! ---------------------------------------------------------
  subroutine boundaries_init(this, namespace, mesh)
    type(boundaries_t),   intent(out) :: this
    type(namespace_t),    intent(in)  :: namespace
    type(mesh_t), target, intent(in)  :: mesh

    integer :: sp, ip, ip_inner, iper, ip_global, idir
#ifdef HAVE_MPI
    integer :: ip_inner_global, ipart
    integer, allocatable :: recv_rem_points(:, :)
    integer :: nper_recv
    integer, allocatable :: send_buffer(:)
    integer :: bsize, status(MPI_STATUS_SIZE)
#endif
    type(block_t) :: blk

    PUSH_SUB(boundaries_init)

    this%mesh => mesh

    nullify(this%per_points)

    if (simul_box_is_periodic(mesh%sb)) then

      !%Variable SpiralBoundaryCondition
      !%Type logical
      !%Default no
      !%Section Mesh
      !%Description
      !% (Experimental) If set to yes, Octopus will apply spin-spiral boundary conditions.
      !% The momentum of the spin spiral is defined by the variable
      !% <tt>TDMomentumTransfer</tt>.
      !%End
      call parse_variable(namespace, 'SpiralBoundaryCondition', .false., this%spiralBC)
      if(this%spiralBC) then
        if(parse_is_defined(namespace, 'TDMomentumTransfer')) then
          if(parse_block(namespace, 'TDMomentumTransfer', blk)==0) then
            do idir = 1, MAX_DIM
              call parse_block_float(blk, 0, idir - 1, this%spiral_q(idir))
              this%spiral_q(idir) = units_to_atomic(unit_one / units_inp%length, this%spiral_q(idir))
            end do
            call messages_experimental("SpiralBoundaryCondition")
          else
            message(1) = "TDMomentumTransfer must be defined if SpiralBoundaryCondition=yes"
            call messages_fatal(1, namespace=namespace)
          end if
        else
          message(1) = "TDMomentumTransfer must be defined if SpiralBoundaryCondition=yes"
          call messages_fatal(1, namespace=namespace)
        end if
      end if
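
      ! A minimal input-file sketch for these variables (illustrative values
      ! only; the momentum components are read in units of inverse length):
      !
      !   SpiralBoundaryCondition = yes
      !   %TDMomentumTransfer
      !     0.1 | 0.0 | 0.0
      !   %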

      sp = mesh%np
      if(mesh%parallel_in_domains) sp = mesh%np + mesh%vp%np_ghost

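      ! The boundary points are stored after the inner (and ghost) points,
      ! i.e. at indices sp + 1 to mesh%np_part. The tables are built in two
      ! passes: this first pass only counts how many boundary points are
      ! periodic copies of local inner points (nper) and, in parallel runs,
      ! how many are copies of points owned by other partitions (nper_recv);
      ! the second pass, after allocation, fills the index tables.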
      !count the number of points that are periodic
      this%nper = 0
#ifdef HAVE_MPI
      nper_recv = 0
#endif
      do ip = sp + 1, mesh%np_part

        ip_global = ip

#ifdef HAVE_MPI
        !translate to a global point
        if(mesh%parallel_in_domains) ip_global = mesh%vp%bndry(ip - sp - 1 + mesh%vp%xbndry)
#endif

        ip_inner = mesh_periodic_point(mesh, ip_global, ip)

#ifdef HAVE_MPI
        !translate back to a local point
        if(mesh%parallel_in_domains) ip_inner = vec_global2local(mesh%vp, ip_inner, mesh%vp%partno)
#endif

        ! If the point is the periodic copy of another point, if it is not
        ! zero (it might be zero in the parallel case), and if it is inside
        ! the grid, then we have to copy its value from the grid points.
        !
        ! If the inner point index is larger than mesh%np, then it is the
        ! periodic copy of a point that is zero, so we don't count it,
        ! as it will be initialized to zero anyway. For different
        ! mixed boundary conditions the last check should be removed.
        !
        if(ip /= ip_inner .and. ip_inner /= 0 .and. ip_inner <= mesh%np) then
          this%nper = this%nper + 1
#ifdef HAVE_MPI
        else if(mesh%parallel_in_domains .and. ip /= ip_inner) then
          nper_recv = nper_recv + 1
#endif
        end if
      end do

      SAFE_ALLOCATE(this%per_points(1:2, 1:this%nper))

#ifdef HAVE_MPI
      if(mesh%parallel_in_domains) then
        SAFE_ALLOCATE(this%per_recv(1:nper_recv, 1:mesh%vp%npart))
        SAFE_ALLOCATE(recv_rem_points(1:nper_recv, 1:mesh%vp%npart))
        SAFE_ALLOCATE(this%nrecv(1:mesh%vp%npart))
        this%nrecv = 0
      end if
#endif
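
      ! Second pass: store, for every periodic boundary point, either the
      ! local inner point it copies (per_points) or, in parallel runs, the
      ! partition that owns the inner point together with the index of that
      ! point on the owning partition (per_recv and recv_rem_points).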

      iper = 0
      do ip = sp + 1, mesh%np_part

        ip_global = ip

        !translate to a global point
#ifdef HAVE_MPI
        if(mesh%parallel_in_domains) ip_global = mesh%vp%bndry(ip - sp - 1 + mesh%vp%xbndry)
#endif

        ip_inner = mesh_periodic_point(mesh, ip_global, ip)

        !translate to local (and keep a copy of the global)
#ifdef HAVE_MPI
        if(mesh%parallel_in_domains) then
          ip_inner_global = ip_inner
          ip_inner = vec_global2local(mesh%vp, ip_inner, mesh%vp%partno)
        end if
#endif

        if(ip /= ip_inner .and. ip_inner /= 0 .and. ip_inner <= mesh%np) then
          iper = iper + 1
          this%per_points(POINT_BOUNDARY, iper) = ip
          this%per_points(POINT_INNER, iper) = ip_inner

#ifdef HAVE_MPI
        else if(mesh%parallel_in_domains .and. ip /= ip_inner) then ! the point is in another node
          ! find in which partition it is
          do ipart = 1, mesh%vp%npart
            if(ipart == mesh%vp%partno) cycle

            ip_inner = vec_global2local(mesh%vp, ip_inner_global, ipart)

            if(ip_inner /= 0) then
              if(ip_inner <= mesh%vp%np_local_vec(ipart)) then
                ! count the points to receive from each node
                this%nrecv(ipart) = this%nrecv(ipart) + 1
                ! and store the number of the point
                this%per_recv(this%nrecv(ipart), ipart) = ip
                ! and where it is in the other partition
                recv_rem_points(this%nrecv(ipart), ipart) = ip_inner

                ASSERT(mesh%vp%rank /= ipart - 1) ! if we are here, the point must be in another node

                exit
              end if
            end if

          end do
#endif
        end if
      end do

#ifdef HAVE_MPI
      if(mesh%parallel_in_domains) then

        ! first we allocate the buffer to be able to use MPI_Bsend
        ! (MPI_Buffer_attach takes the size in bytes; bsize*4 assumes 4-byte integers)
        bsize = mesh%vp%npart - 1 + nper_recv + MPI_BSEND_OVERHEAD*2*(mesh%vp%npart - 1)
        SAFE_ALLOCATE(send_buffer(1:bsize))
        call MPI_Buffer_attach(send_buffer(1), bsize*4, mpi_err)

        ! Now we communicate to each node the points they will have to
        ! send us. Probably this could be done without communication,
        ! but this way it seems simpler to implement.

        ! We send the number of points we expect to receive.
        do ipart = 1, mesh%vp%npart
          if(ipart == mesh%vp%partno) cycle
          call MPI_Bsend(this%nrecv(ipart), 1, MPI_INTEGER, ipart - 1, 0, mesh%vp%comm, mpi_err)
        end do

        ! And we receive it
        SAFE_ALLOCATE(this%nsend(1:mesh%vp%npart))
        this%nsend = 0
        do ipart = 1, mesh%vp%npart
          if(ipart == mesh%vp%partno) cycle
          call MPI_Recv(this%nsend(ipart), 1, MPI_INTEGER, ipart - 1, 0, mesh%vp%comm, status, mpi_err)
        end do

        ! Now we send the indices of the points
        do ipart = 1, mesh%vp%npart
          if(ipart == mesh%vp%partno .or. this%nrecv(ipart) == 0) cycle
          call MPI_Bsend(recv_rem_points(1, ipart), this%nrecv(ipart), MPI_INTEGER, ipart - 1, 1, mesh%vp%comm, mpi_err)
        end do

        SAFE_ALLOCATE(this%per_send(1:maxval(this%nsend), 1:mesh%vp%npart))

        ! And we receive them
        do ipart = 1, mesh%vp%npart
          if(ipart == mesh%vp%partno .or. this%nsend(ipart) == 0) cycle
          call MPI_Recv(this%per_send(1, ipart), this%nsend(ipart), MPI_INTEGER, &
               ipart - 1, 1, mesh%vp%comm, status, mpi_err)
        end do

        ! we no longer need this
        SAFE_DEALLOCATE_A(recv_rem_points)

        call MPI_Buffer_detach(send_buffer(1), bsize, mpi_err)
        SAFE_DEALLOCATE_A(send_buffer)

      end if
#endif
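
      ! After this exchange, per_recv(1:nrecv(ipart), ipart) holds the local
      ! boundary points whose values come from partition ipart, while
      ! per_send(1:nsend(ipart), ipart) holds the local points whose values
      ! have to be sent to it.
      !
      ! If an accelerator is in use, the index tables built above are also
      ! mirrored in device memory, so that the routines applying the
      ! boundary conditions can use them without extra host-device copies.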

      if(accel_is_enabled()) then
        call accel_create_buffer(this%buff_per_points, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, 2*this%nper)
        call accel_write_buffer(this%buff_per_points, 2*this%nper, this%per_points)

        if(mesh%parallel_in_domains) then
          call accel_create_buffer(this%buff_per_send, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, product(ubound(this%per_send)))
          call accel_write_buffer(this%buff_per_send, product(ubound(this%per_send)), this%per_send)

          call accel_create_buffer(this%buff_per_recv, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, product(ubound(this%per_recv)))
          call accel_write_buffer(this%buff_per_recv, product(ubound(this%per_recv)), this%per_recv)

          call accel_create_buffer(this%buff_nsend, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, mesh%vp%npart)
          call accel_write_buffer(this%buff_nsend, mesh%vp%npart, this%nsend)

          call accel_create_buffer(this%buff_nrecv, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, mesh%vp%npart)
          call accel_write_buffer(this%buff_nrecv, mesh%vp%npart, this%nrecv)
        end if
      end if

    end if

    POP_SUB(boundaries_init)
  end subroutine boundaries_init

  ! ---------------------------------------------------------

  subroutine boundaries_end(this)
    type(boundaries_t),  intent(inout) :: this

    PUSH_SUB(boundaries_end)

    if(simul_box_is_periodic(this%mesh%sb)) then
      if(this%mesh%parallel_in_domains) then

        ASSERT(associated(this%nsend))
        ASSERT(associated(this%nrecv))

        SAFE_DEALLOCATE_P(this%per_send)
        SAFE_DEALLOCATE_P(this%per_recv)
        SAFE_DEALLOCATE_P(this%nsend)
        SAFE_DEALLOCATE_P(this%nrecv)

        if(accel_is_enabled()) then
          call accel_release_buffer(this%buff_per_send)
          call accel_release_buffer(this%buff_per_recv)
          call accel_release_buffer(this%buff_nsend)
          call accel_release_buffer(this%buff_nrecv)
        end if
      end if

      if(accel_is_enabled()) call accel_release_buffer(this%buff_per_points)

      SAFE_DEALLOCATE_P(this%per_points)
    end if

    POP_SUB(boundaries_end)
  end subroutine boundaries_end

  ! -------------------------------------------------------

  subroutine boundaries_set_batch(this, ffb, phase_correction)
    type(boundaries_t), intent(in)    :: this
    class(batch_t),     intent(inout) :: ffb
    CMPLX, optional,    intent(in)    :: phase_correction(:)

    PUSH_SUB(boundaries_set_batch)

    if(ffb%type() == TYPE_FLOAT) then
      call dboundaries_set_batch(this, ffb, phase_correction)
    else if(ffb%type() == TYPE_CMPLX) then
      call zboundaries_set_batch(this, ffb, phase_correction)
    else
      ASSERT(.false.)
    end if

    POP_SUB(boundaries_set_batch)
  end subroutine boundaries_set_batch
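
  ! Typical usage (illustrative sketch; "bnd" and "psib" are assumed names
  ! for a boundaries_t and a batch_t set up by the caller): once
  ! boundaries_init(bnd, namespace, mesh) has been called, the boundary
  ! values of a batch are updated with
  !
  !   call boundaries_set(bnd, psib)
  !
  ! The same generic interface also accepts a single real or complex mesh
  ! function.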

#include "undef.F90"
#include "complex.F90"
#include "boundaries_inc.F90"

#include "undef.F90"
#include "real.F90"
#include "boundaries_inc.F90"

end module boundaries_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End: