1!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
2!!
3!! This program is free software; you can redistribute it and/or modify
4!! it under the terms of the GNU General Public License as published by
5!! the Free Software Foundation; either version 2, or (at your option)
6!! any later version.
7!!
8!! This program is distributed in the hope that it will be useful,
9!! but WITHOUT ANY WARRANTY; without even the implied warranty of
10!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11!! GNU General Public License for more details.
12!!
13!! You should have received a copy of the GNU General Public License
14!! along with this program; if not, write to the Free Software
15!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
16!! 02110-1301, USA.
17!!
18
19#include "global.h"
20
21module nl_operator_oct_m
22  use accel_oct_m
23  use batch_oct_m
24  use boundaries_oct_m
25  use global_oct_m
26  use index_oct_m
27  use iso_c_binding
28  use loct_pointer_oct_m
29  use math_oct_m
30  use mesh_oct_m
31  use messages_oct_m
32  use mpi_oct_m
33  use multicomm_oct_m
34  use namespace_oct_m
35  use operate_f_oct_m
36  use par_vec_oct_m
37  use parser_oct_m
38  use profiling_oct_m
39  use simul_box_oct_m
40  use stencil_oct_m
41  use types_oct_m
42  use varinfo_oct_m
43
44  implicit none
45
46  private
47  public ::                     &
48    nl_operator_t,              &
49    nl_operator_index_t,        &
50    nl_operator_global_init,    &
51    nl_operator_global_end,     &
52    nl_operator_init,           &
53    nl_operator_copy,           &
54    nl_operator_build,          &
55    dnl_operator_operate,       &
56    znl_operator_operate,       &
57    dnl_operator_operate_batch, &
58    znl_operator_operate_batch, &
59    dnl_operator_operate_diag,  &
60    znl_operator_operate_diag,  &
61    nl_operator_end,            &
62    nl_operator_skewadjoint,    &
63    nl_operator_selfadjoint,    &
64    nl_operator_get_index,      &
65    nl_operator_update_weights, &
66    nl_operator_np_zero_bc,         &
67    nl_operator_compact_boundaries
68
69  type nl_operator_index_t
70    private
71    integer          :: nri
72    integer, pointer :: imin(:)
73    integer, pointer :: imax(:)
74    integer, pointer :: ri(:, :)
75  end type nl_operator_index_t
76
77  type nl_operator_t
78    private
79    type(stencil_t),  public :: stencil
80    type(mesh_t), pointer    :: mesh      !< pointer to the underlying mesh
81    integer, pointer         :: nn(:)     !< the size of the stencil at each point (for curvilinear coordinates)
82    integer,          public :: np        !< number of points in mesh
83    !> When running in parallel mode, the next three arrays are unique on each node.
84    integer, pointer, public :: index(:,:)    !< index of the points. Unique on each parallel process.
85    FLOAT,   pointer, public :: w(:,:)        !< weights. Unique on each parallel process.
86
87    logical,          public :: const_w   !< are the weights independent of index i
88
89    character(len=40) :: label
90
91    !> the compressed index of grid points
92    integer, public :: nri
93    integer, pointer, public :: ri(:,:)
94    integer, pointer, public :: rimap(:)
95    integer, pointer, public :: rimap_inv(:)
96
97    integer                   :: ninner
98    integer                   :: nouter
99
100    type(nl_operator_index_t) :: inner
101    type(nl_operator_index_t) :: outer
102
103    type(accel_kernel_t) :: kernel
104    type(accel_mem_t) :: buff_imin
105    type(accel_mem_t) :: buff_imax
106    type(accel_mem_t) :: buff_ri
107    type(accel_mem_t) :: buff_map
108    type(accel_mem_t) :: buff_all
109    type(accel_mem_t) :: buff_inner
110    type(accel_mem_t) :: buff_outer
111    type(accel_mem_t) :: buff_stencil
112    type(accel_mem_t) :: buff_ip_to_xyz
113    type(accel_mem_t) :: buff_xyz_to_ip
114  end type nl_operator_t
115
116  integer, parameter :: &
117       OP_FORTRAN = 0,  &
118       OP_VEC     = 1,  &
119       OP_MIN     = OP_FORTRAN, &
120       OP_MAX     = OP_VEC
121
122  integer, parameter ::     &
123    OP_INVMAP    = 1,       &
124    OP_MAP       = 2,       &
125    OP_NOMAP     = 3
126
127  integer, public, parameter :: OP_ALL = 3, OP_INNER = 1, OP_OUTER = 2
128
129  logical :: compact_boundaries
130
131  interface
132    integer function op_is_available(opid, type)
133      implicit none
134      integer, intent(in) :: opid, type
135    end function op_is_available
136  end interface
137
138  integer :: dfunction_global = -1
139  integer :: zfunction_global = -1
140  integer :: sfunction_global = -1
141  integer :: cfunction_global = -1
142  integer :: function_opencl
143
144  type(profile_t), save :: operate_batch_prof
145
146contains
147
148  ! ---------------------------------------------------------
149  subroutine nl_operator_global_init(namespace)
150    type(namespace_t),         intent(in)    :: namespace
151
152    integer :: default
153
154    PUSH_SUB(nl_operator_global_init)
155
156    !%Variable OperateDouble
157    !%Type integer
158    !%Section Execution::Optimization
159    !%Default optimized
160    !%Description
161    !% This variable selects the subroutine used to apply non-local
162    !% operators over the grid for real functions.
163    !%Option fortran 0
164    !% The standard Fortran function.
165    !%Option optimized 1
166    !% This version is optimized using vector primitives (if available).
167    !%End
168
169    !%Variable OperateComplex
170    !%Type integer
171    !%Section Execution::Optimization
172    !%Default optimized
173    !%Description
174    !% This variable selects the subroutine used to apply non-local
175    !% operators over the grid for complex functions.
176    !%Option fortran 0
177    !% The standard Fortran function.
178    !%Option optimized 1
179    !% This version is optimized using vector primitives (if available).
180    !%End
181
182    default = OP_VEC
183
184    call parse_variable(namespace, 'OperateDouble', default, dfunction_global)
185    if(.not.varinfo_valid_option('OperateDouble', dfunction_global)) call messages_input_error(namespace, 'OperateDouble')
186
187    call parse_variable(namespace, 'OperateComplex', default, zfunction_global)
188    if(.not.varinfo_valid_option('OperateComplex', zfunction_global)) call messages_input_error(namespace, 'OperateComplex')
189
190
191    !%Variable OperateSingle
192    !%Type integer
193    !%Section Execution::Optimization
194    !%Default optimized
195    !%Description
196    !% This variable selects the subroutine used to apply non-local
197    !% operators over the grid for single-precision real functions.
198    !%Option fortran 0
199    !% The standard Fortran function.
200    !%Option optimized 1
201    !% This version is optimized using vector primitives (if available).
202    !%End
203
204    !%Variable OperateComplexSingle
205    !%Type integer
206    !%Section Execution::Optimization
207    !%Default optimized
208    !%Description
209    !% This variable selects the subroutine used to apply non-local
210    !% operators over the grid for single-precision complex functions.
211    !%Option fortran 0
212    !% The standard Fortran function.
213    !%Option optimized 1
214    !% This version is optimized using vector primitives (if available).
215    !%End
216
217    call parse_variable(namespace, 'OperateSingle', OP_FORTRAN, sfunction_global)
218    if(.not.varinfo_valid_option('OperateSingle', sfunction_global)) call messages_input_error(namespace, 'OperateSingle')
219
220    call parse_variable(namespace, 'OperateComplexSingle', OP_FORTRAN, cfunction_global)
221    if(.not.varinfo_valid_option('OperateComplexSingle', cfunction_global)) then
222      call messages_input_error(namespace, 'OperateComplexSingle')
223    end if
224
225    if(accel_is_enabled()) then
226
227      !%Variable OperateAccel
228      !%Type integer
229      !%Default map
230      !%Section Execution::Optimization
231      !%Description
232      !% This variable selects the subroutine used to apply non-local
233      !% operators over the grid when an accelerator device is used.
234      !%Option invmap 1
235      !% The standard implementation ported to OpenCL.
236      !%Option map 2
237      !% A different version, more suitable for GPUs.
238      !%Option nomap 3
239      !% (Experimental) This version does not use a map.
240      !%End
241      call parse_variable(namespace, 'OperateAccel',  OP_MAP, function_opencl)
242
243      call messages_obsolete_variable(namespace, 'OperateOpenCL', 'OperateAccel')
244
245    end if
246
247    !%Variable NLOperatorCompactBoundaries
248    !%Type logical
249    !%Default no
250    !%Section Execution::Optimization
251    !%Description
252    !% (Experimental) When set to yes, for finite systems Octopus will
253    !% map boundary points for finite-differences operators to a few
254    !% memory locations. This increases performance, however it is
255    !% experimental and has not been thoroughly tested.
256    !%End
257
258    call parse_variable(namespace, 'NLOperatorCompactBoundaries', .false., compact_boundaries)
259
260    if(compact_boundaries) then
261      call messages_experimental('NLOperatorCompactBoundaries')
262    end if
263
264    POP_SUB(nl_operator_global_init)
265  end subroutine nl_operator_global_init
266
267  ! ---------------------------------------------------------
268
269  subroutine nl_operator_global_end()
270    PUSH_SUB(nl_operator_global_end)
271
272    POP_SUB(nl_operator_global_end)
273  end subroutine nl_operator_global_end
274
275  ! ---------------------------------------------------------
276  subroutine nl_operator_init(op, label)
277    type(nl_operator_t), intent(out) :: op
278    character(len=*),    intent(in)  :: label
279
280    PUSH_SUB(nl_operator_init)
281
282    nullify(op%mesh, op%index, op%w, op%ri, op%rimap, op%rimap_inv)
283    nullify(op%inner%imin, op%inner%imax, op%inner%ri)
284    nullify(op%outer%imin, op%outer%imax, op%outer%ri)
285    nullify(op%nn)
286
287    op%label = label
288
289    call accel_mem_nullify(op%buff_imin)
290    call accel_mem_nullify(op%buff_imax)
291    call accel_mem_nullify(op%buff_ri)
292    call accel_mem_nullify(op%buff_map)
293    call accel_mem_nullify(op%buff_all)
294    call accel_mem_nullify(op%buff_inner)
295    call accel_mem_nullify(op%buff_outer)
296    call accel_mem_nullify(op%buff_stencil)
297    call accel_mem_nullify(op%buff_ip_to_xyz)
298    call accel_mem_nullify(op%buff_xyz_to_ip)
299
300    POP_SUB(nl_operator_init)
301  end subroutine nl_operator_init
302
303
304  ! ---------------------------------------------------------
305  subroutine nl_operator_copy(opo, opi)
306    type(nl_operator_t),         intent(out) :: opo
307    type(nl_operator_t), target, intent(in)  :: opi
308
309    PUSH_SUB(nl_operator_copy)
310
311    ! We cannot currently copy the GPU kernel for the nl_operator
312    ASSERT(.not. accel_is_enabled())
313
314    call nl_operator_init(opo, opi%label)
315
316    call stencil_copy(opi%stencil, opo%stencil)
317
318    opo%np           =  opi%np
319    opo%mesh         => opi%mesh
320
321    call loct_pointer_copy(opo%nn, opi%nn)
322    call loct_pointer_copy(opo%index, opi%index)
323    call loct_pointer_copy(opo%w, opi%w)
324
325    opo%const_w   = opi%const_w
326
327    opo%nri       =  opi%nri
328    ASSERT(associated(opi%ri))
329
330    call loct_pointer_copy(opo%ri, opi%ri)
331    call loct_pointer_copy(opo%rimap, opi%rimap)
332    call loct_pointer_copy(opo%rimap_inv, opi%rimap_inv)
333
334    if(opi%mesh%parallel_in_domains) then
335      opo%inner%nri = opi%inner%nri
336      call loct_pointer_copy(opo%inner%imin, opi%inner%imin)
337      call loct_pointer_copy(opo%inner%imax, opi%inner%imax)
338      call loct_pointer_copy(opo%inner%ri,   opi%inner%ri)
339
340      opo%outer%nri = opi%outer%nri
341      call loct_pointer_copy(opo%outer%imin, opi%outer%imin)
342      call loct_pointer_copy(opo%outer%imax, opi%outer%imax)
343      call loct_pointer_copy(opo%outer%ri,   opi%outer%ri)
344    end if
345
346
347    POP_SUB(nl_operator_copy)
348  end subroutine nl_operator_copy
349
350
351  ! ---------------------------------------------------------
352  subroutine nl_operator_build(mesh, op, np, const_w)
353    type(mesh_t), target, intent(in)    :: mesh
354    type(nl_operator_t),  intent(inout) :: op
355    integer,              intent(in)    :: np       !< Number of (local) points.
356    logical, optional,    intent(in)    :: const_w  !< are the weights constant (independent of the point)
357
358    integer :: ii, jj, p1(MAX_DIM), time, current, size
359    integer, allocatable :: st1(:), st2(:), st1r(:), stencil(:, :)
360    integer :: nn
361    integer :: ir, maxp, iinner, iouter
362    logical :: change, force_change
363    character(len=200) :: flags
364    integer, allocatable :: inner_points(:), outer_points(:), all_points(:)
365
366    PUSH_SUB(nl_operator_build)
367
368    if(mesh%parallel_in_domains .and. .not. const_w) then
369      call messages_experimental('Domain parallelization with curvilinear coordinates')
370    end if
371
372    ASSERT(np > 0)
373
374    ! store values in structure
375    op%np       = np
376    op%mesh     => mesh
377    op%const_w  = .false.
378    if(present(const_w )) op%const_w  = const_w
379
380    ! allocate weights op%w
381    if(op%const_w) then
382      SAFE_ALLOCATE(op%w(1:op%stencil%size, 1:1))
383      if(debug%info) then
384        message(1) = 'Info: nl_operator_build: working with constant weights.'
385        call messages_info(1)
386      end if
387    else
388      SAFE_ALLOCATE(op%w(1:op%stencil%size, 1:op%np))
389      if(debug%info) then
390        message(1) = 'Info: nl_operator_build: working with non-constant weights.'
391        call messages_info(1)
392      end if
393    end if
394
395    ! set initially to zero
396    op%w = M_ZERO
397
398    ! Build lookup table
399    SAFE_ALLOCATE(st1(1:op%stencil%size))
400    SAFE_ALLOCATE(st1r(1:op%stencil%size))
401    SAFE_ALLOCATE(st2(1:op%stencil%size))
402
403    op%nri = 0
404    do time = 1, 2
405      st2 = 0
406      do ii = 1, np
407        p1 = 0
408        if(mesh%parallel_in_domains) then
409          ! When running in parallel, get global number of
410          ! point ii.
411          call index_to_coords(mesh%idx, &
412            mesh%vp%local(mesh%vp%xlocal + ii - 1), p1)
413        else
414          call index_to_coords(mesh%idx, ii, p1)
415        end if
416
417        do jj = 1, op%stencil%size
418          ! Get global index of p1 plus current stencil point.
419          if(mesh%sb%mr_flag) then
420            st1(jj) = index_from_coords(mesh%idx, &
421                 p1(1:MAX_DIM) + mesh%resolution(p1(1), p1(2), p1(3))*op%stencil%points(1:MAX_DIM, jj))
422          else
423            st1(jj) = index_from_coords(mesh%idx, p1(1:MAX_DIM) + op%stencil%points(1:MAX_DIM, jj))
424          end if
425
426          if(mesh%parallel_in_domains) then
427            ! When running parallel, translate this global
428            ! number back to a local number.
429            st1(jj) = vec_global2local(mesh%vp, st1(jj), mesh%vp%partno)
430          end if
431
432          ! if boundary conditions are zero, we can remap boundary
433          ! points to reduce memory accesses. We cannot do this for the
434          ! first point, since it is used to build the weights, so it
435          ! has to have the positions right
436          if(ii > 1 .and. compact_boundaries .and. mesh_compact_boundaries(mesh)) then
437            st1(jj) = min(st1(jj), mesh%np + 1)
438          end if
439          ASSERT(st1(jj) > 0)
440        end do
441
442        st1(1:op%stencil%size) = st1(1:op%stencil%size) - ii
443
444        change = any(st1 /= st2)
445
446        !the next is to detect when we move from a point that does not
447        !have boundary points as neighbours to one that has
448        force_change = any(st1 + ii > mesh%np) .and. all(st2 + ii - 1 <= mesh%np)
449
450        if(change .and. compact_boundaries .and. mesh_compact_boundaries(mesh)) then
451          !try to repair it by changing the boundary points
452          do jj = 1, op%stencil%size
453            if(st1(jj) + ii > mesh%np .and. st2(jj) + ii - 1 > mesh%np .and. st2(jj) + ii <= mesh%np_part) then
454              st1r(jj) = st2(jj)
455            else
456              st1r(jj) = st1(jj)
457            end if
458          end do
459
460          change = any(st1r /= st2)
461
462          if(.not. change) st1 = st1r
463        end if
464
465        ! if the stencil changes
466        if (change .or. force_change) then
467          !store it
468          st2 = st1
469
470          !first time, just count
471          if ( time == 1 ) op%nri = op%nri + 1
472
473          !second time, store
474          if( time == 2 ) then
475            current = current + 1
476            op%ri(1:op%stencil%size, current) = st1(1:op%stencil%size)
477          end if
478        end if
479
480        if(time == 2) op%rimap(ii) = current
481
482      end do
483
484      !after counting, allocate
485      if (time == 1 ) then
486        SAFE_ALLOCATE(op%ri(1:op%stencil%size, 1:op%nri))
487        SAFE_ALLOCATE(op%rimap(1:op%np))
488        SAFE_ALLOCATE(op%rimap_inv(1:op%nri + 1))
489        op%ri        = 0
490        op%rimap     = 0
491        op%rimap_inv = 0
492        current      = 0
493
494        ! the sizes
495        if(mesh%use_curvilinear) then
496          SAFE_ALLOCATE(op%nn(1:op%nri))
497          ! for the moment all the sizes are the same
498          op%nn = op%stencil%size
499        end if
500      end if
501
502    end do
503
504    !the inverse mapping
505    op%rimap_inv(1) = 0
506    do jj = 1, op%np
507      op%rimap_inv(op%rimap(jj) + 1) = jj
508    end do
509    op%rimap_inv(op%nri + 1) = op%np
510
511    SAFE_DEALLOCATE_A(st1)
512    SAFE_DEALLOCATE_A(st1r)
513    SAFE_DEALLOCATE_A(st2)
514
515    do jj = 1, op%nri
516      nn = op%rimap_inv(jj + 1) - op%rimap_inv(jj)
517    end do
518
519    if(op%mesh%parallel_in_domains) then
520      !now build the arrays required to apply the nl_operator by parts
521
522      !count points
523      op%inner%nri = 0
524      op%outer%nri = 0
525      do ir = 1, op%nri
526        maxp = op%rimap_inv(ir + 1) + maxval(op%ri(1:op%stencil%size, ir))
527        if (maxp <= np) then
528          !inner point
529          op%inner%nri = op%inner%nri + 1
530          ASSERT(op%inner%nri <= op%nri)
531        else
532          !outer point
533          op%outer%nri = op%outer%nri + 1
534          ASSERT(op%outer%nri <= op%nri)
535        end if
536      end do
537
538      ASSERT(op%inner%nri + op%outer%nri == op%nri)
539
540      SAFE_ALLOCATE(op%inner%imin(1:op%inner%nri + 1))
541      SAFE_ALLOCATE(op%inner%imax(1:op%inner%nri))
542      SAFE_ALLOCATE(op%inner%ri(1:op%stencil%size, 1:op%inner%nri))
543
544      SAFE_ALLOCATE(op%outer%imin(1:op%outer%nri + 1))
545      SAFE_ALLOCATE(op%outer%imax(1:op%outer%nri))
546      SAFE_ALLOCATE(op%outer%ri(1:op%stencil%size, 1:op%outer%nri))
547
548      !now populate the arrays
549      iinner = 0
550      iouter = 0
551      do ir = 1, op%nri
552        maxp = op%rimap_inv(ir + 1) + maxval(op%ri(1:op%stencil%size, ir))
553        if (maxp <= np) then
554          !inner point
555          iinner = iinner + 1
556          op%inner%imin(iinner) = op%rimap_inv(ir)
557          op%inner%imax(iinner) = op%rimap_inv(ir + 1)
558          op%inner%ri(1:op%stencil%size, iinner) = op%ri(1:op%stencil%size, ir)
559        else
560          !outer point
561          iouter = iouter + 1
562          op%outer%imin(iouter) = op%rimap_inv(ir)
563          op%outer%imax(iouter) = op%rimap_inv(ir + 1)
564          op%outer%ri(1:op%stencil%size, iouter) = op%ri(1:op%stencil%size, ir)
565        end if
566      end do
567
568      !verify that all points in the inner operator are actually inner
569      do ir = 1, op%inner%nri
570        do ii = op%inner%imin(ir) + 1, op%inner%imax(ir)
571          ASSERT(all(ii + op%inner%ri(1:op%stencil%size, ir) <= mesh%np))
572        end do
573      end do
574
575    end if
576
577    if(accel_is_enabled() .and. op%const_w) then
578
579      write(flags, '(i5)') op%stencil%size
580      flags='-DNDIM=3 -DSTENCIL_SIZE='//trim(adjustl(flags))
581
582      if(op%mesh%parallel_in_domains) flags = '-DINDIRECT '//trim(flags)
583
584      select case(function_opencl)
585      case(OP_INVMAP)
586        call accel_kernel_build(op%kernel, 'operate.cl', 'operate', flags)
587      case(OP_MAP)
588        call accel_kernel_build(op%kernel, 'operate.cl', 'operate_map', flags)
589      case(OP_NOMAP)
590        call accel_kernel_build(op%kernel, 'operate.cl', 'operate_nomap', flags)
591      end select
592
593      call accel_create_buffer(op%buff_ri, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, op%nri*op%stencil%size)
594      call accel_write_buffer(op%buff_ri, op%nri*op%stencil%size, op%ri)
595
596      select case(function_opencl)
597      case(OP_INVMAP)
598        call accel_create_buffer(op%buff_imin, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, op%nri)
599        call accel_write_buffer(op%buff_imin, op%nri, op%rimap_inv(1:))
600        call accel_create_buffer(op%buff_imax, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, op%nri)
601        call accel_write_buffer(op%buff_imax, op%nri, op%rimap_inv(2:))
602
603      case(OP_MAP)
604
605        call accel_create_buffer(op%buff_map, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, pad(op%mesh%np, accel_max_workgroup_size()))
606        call accel_write_buffer(op%buff_map, op%mesh%np, (op%rimap - 1)*op%stencil%size)
607
608        if(op%mesh%parallel_in_domains) then
609
610          SAFE_ALLOCATE(inner_points(1:op%mesh%np))
611          SAFE_ALLOCATE(outer_points(1:op%mesh%np))
612          SAFE_ALLOCATE(all_points(1:op%mesh%np))
613
614          op%ninner = 0
615          op%nouter = 0
616
617          do ii = 1, op%mesh%np
618            all_points(ii) = ii - 1
619            maxp = ii + maxval(op%ri(1:op%stencil%size, op%rimap(ii)))
620            if(maxp <= op%mesh%np) then
621              op%ninner = op%ninner + 1
622              inner_points(op%ninner) = ii - 1
623            else
624              op%nouter = op%nouter + 1
625              outer_points(op%nouter) = ii - 1
626            end if
627          end do
628
629          call accel_create_buffer(op%buff_all, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, pad(op%mesh%np, accel_max_workgroup_size()))
630          call accel_write_buffer(op%buff_all, op%mesh%np, all_points)
631
632          call accel_create_buffer(op%buff_inner, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, pad(op%ninner, accel_max_workgroup_size()))
633          call accel_write_buffer(op%buff_inner, op%ninner, inner_points)
634
635          call accel_create_buffer(op%buff_outer, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, pad(op%nouter, accel_max_workgroup_size()))
636          call accel_write_buffer(op%buff_outer, op%nouter, outer_points)
637
638          SAFE_DEALLOCATE_A(inner_points)
639          SAFE_DEALLOCATE_A(outer_points)
640          SAFE_DEALLOCATE_A(all_points)
641
642        end if
643
644      case(OP_NOMAP)
645
646        ASSERT(op%mesh%sb%dim == 3)
647        ASSERT(.not. op%mesh%parallel_in_domains)
648
649        call accel_create_buffer(op%buff_map, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, pad(op%mesh%np, accel_max_workgroup_size()))
650        call accel_write_buffer(op%buff_map, op%mesh%np, (op%rimap - 1)*op%stencil%size)
651
652        SAFE_ALLOCATE(stencil(1:op%mesh%sb%dim, 1:op%stencil%size + 1))
653
654        stencil(1:op%mesh%sb%dim, 1:op%stencil%size) = op%stencil%points(1:op%mesh%sb%dim, 1:op%stencil%size)
655
656        stencil(1, op%stencil%size + 1) = 1
657        stencil(2, op%stencil%size + 1) = mesh%idx%nr(2, 1) - mesh%idx%nr(1, 1) + 1
658        stencil(3, op%stencil%size + 1) = stencil(2, op%stencil%size + 1)*(mesh%idx%nr(2, 2) - mesh%idx%nr(1, 2) + 1)
659
660        call accel_create_buffer(op%buff_stencil, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, op%mesh%sb%dim*(op%stencil%size + 1))
661        call accel_write_buffer(op%buff_stencil, op%mesh%sb%dim*(op%stencil%size + 1), stencil)
662
663        SAFE_DEALLOCATE_A(stencil)
664
665        size = product(mesh%idx%nr(2, 1:op%mesh%sb%dim) - mesh%idx%nr(1, 1:op%mesh%sb%dim) + 1)
666
667        call accel_create_buffer(op%buff_xyz_to_ip, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, size)
668        call accel_write_buffer(op%buff_xyz_to_ip, size, op%mesh%idx%lxyz_inv - 1)
669
670        SAFE_ALLOCATE(stencil(1:op%mesh%sb%dim, 1:mesh%np_part))
671
672        do jj = 1, op%mesh%sb%dim
673          stencil(jj, 1:mesh%np_part) = op%mesh%idx%lxyz(1:mesh%np_part, jj) - mesh%idx%nr(1, jj)
674        end do
675
676        ASSERT(minval(stencil) == 0)
677
678        call accel_create_buffer(op%buff_ip_to_xyz, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, op%mesh%np_part*op%mesh%sb%dim)
679        call accel_write_buffer(op%buff_ip_to_xyz, op%mesh%np_part*op%mesh%sb%dim, stencil)
680
681        SAFE_DEALLOCATE_A(stencil)
682
683      end select
684    end if
685
686    POP_SUB(nl_operator_build)
687
688  end subroutine nl_operator_build
689
690  ! ---------------------------------------------------------
691  subroutine nl_operator_update_weights(this)
692    type(nl_operator_t), intent(inout)  :: this
693
694    integer :: istencil, idir
695
696    PUSH_SUB(nl_operator_update_weights)
697
698    if(debug%info) then
699
700      write(message(1), '(3a)') 'Debug info: Finite difference weights for ', trim(this%label), '.'
701      write(message(2), '(a)')  '            Spacing:'
702      do idir = 1, this%mesh%sb%dim
703        write(message(2), '(a,f16.8)') trim(message(2)), this%mesh%spacing(idir)
704      end do
705      call messages_info(2)
706
707      do istencil = 1, this%stencil%size
708        write(message(1), '(a,i3,3i4,f25.10)') '      ', istencil, this%stencil%points(1:3, istencil), this%w(istencil, 1)
709        call messages_info(1)
710      end do
711
712    end if
713
714    POP_SUB(nl_operator_update_weights)
715
716  end subroutine nl_operator_update_weights
717
718  ! ---------------------------------------------------------
  !> op has to be initialised and built; opt is overwritten with a copy of op whose weights are replaced.
720  subroutine nl_operator_skewadjoint(op, opt, mesh)
721    type(nl_operator_t), target, intent(in)  :: op
722    type(nl_operator_t), target, intent(out) :: opt
723    type(mesh_t),        target, intent(in)  :: mesh
724
725    integer          :: ip, jp, kp, lp, index
726    FLOAT, pointer   :: vol_pp(:)
727
728    type(nl_operator_t), pointer :: opg, opgt
729
730    PUSH_SUB(nl_operator_skewadjoint)
731
732    call nl_operator_copy(opt, op)
733
734#if defined(HAVE_MPI)
735    if(mesh%parallel_in_domains) then
736      SAFE_ALLOCATE(opg)
737      SAFE_ALLOCATE(opgt)
738      call nl_operator_allgather(op, opg)
739      call nl_operator_init(opgt, op%label)
740      call nl_operator_copy(opgt, opg)
741      SAFE_ALLOCATE(vol_pp(1:mesh%np_global))
742      call vec_allgather(mesh%vp, vol_pp, mesh%vol_pp)
743    else
744#endif
745      opg  => op
746      opgt => opt
747      vol_pp => mesh%vol_pp
748#if defined(HAVE_MPI)
749    end if
750#endif
751
752    opgt%w = M_ZERO
753    do ip = 1, mesh%np_global
754      do jp = 1, op%stencil%size
755        index = nl_operator_get_index(opg, jp, ip)
756        if(index <= mesh%np_global) then
757          do lp = 1, op%stencil%size
758            kp = nl_operator_get_index(opg, lp, index)
759            if( kp == ip ) then
760              if(.not.op%const_w) then
761                opgt%w(jp, ip) = M_HALF*opg%w(jp, ip) - M_HALF*(vol_pp(index)/vol_pp(ip))*opg%w(lp, index)
762              else
763                opgt%w(jp, 1) = opg%w(lp, 1)
764              end if
765            end if
766          end do
767        end if
768      end do
769    end do
770
771#if defined(HAVE_MPI)
772    if(mesh%parallel_in_domains) then
773      SAFE_DEALLOCATE_P(vol_pp)
774      do ip = 1, mesh%vp%np_local
775        opt%w(:, ip) = opgt%w(:, mesh%vp%local(mesh%vp%xlocal+ip-1))
776      end do
777      call nl_operator_end(opg)
778      call nl_operator_end(opgt)
779      SAFE_DEALLOCATE_P(opg)
780      SAFE_DEALLOCATE_P(opgt)
781    end if
782#endif
783
784    POP_SUB(nl_operator_skewadjoint)
785  end subroutine nl_operator_skewadjoint
786
787
788  ! ---------------------------------------------------------
789  subroutine nl_operator_selfadjoint(op, opt, mesh)
790    type(nl_operator_t), target, intent(in)  :: op
791    type(nl_operator_t), target, intent(out) :: opt
792    type(mesh_t),        target, intent(in)  :: mesh
793
794    integer          :: ip, jp, kp, lp, index
795    FLOAT, pointer   :: vol_pp(:)
796
797    type(nl_operator_t), pointer :: opg, opgt
798
799    PUSH_SUB(nl_operator_selfadjoint)
800
801    call nl_operator_copy(opt, op)
802
803    if(mesh%parallel_in_domains) then
804#if defined(HAVE_MPI)
805      SAFE_ALLOCATE(opg)
806      SAFE_ALLOCATE(opgt)
807      call nl_operator_allgather(op, opg)
808      call nl_operator_init(opgt, op%label)
809      opgt = opg
810      SAFE_ALLOCATE(vol_pp(1:mesh%np_global))
811      call vec_allgather(mesh%vp, vol_pp, mesh%vol_pp)
812#else
813      ! avoid appearance of using vol_pp uninitialized
814      vol_pp => mesh%vol_pp
815      ASSERT(.false.)
816#endif
817    else
818      opg  => op
819      opgt => opt
820      vol_pp => mesh%vol_pp
821    end if
822
823    opgt%w = M_ZERO
824    do ip = 1, mesh%np_global
825      do jp = 1, op%stencil%size
826        index = nl_operator_get_index(opg, jp, ip)
827
828        if(index <= mesh%np_global) then
829          do lp = 1, op%stencil%size
830            kp = nl_operator_get_index(opg, lp, index)
831            if( kp == ip ) then
832              if(.not.op%const_w) then
833                opgt%w(jp, ip) = M_HALF*opg%w(jp, ip) + M_HALF*(vol_pp(index)/vol_pp(ip))*opg%w(lp, index)
834              else
835                opgt%w(jp, 1) = opg%w(lp, 1)
836              end if
837            end if
838          end do
839        end if
840
841      end do
842    end do
843
844#if defined(HAVE_MPI)
845    if(mesh%parallel_in_domains) then
846      SAFE_DEALLOCATE_P(vol_pp)
847      do ip = 1, mesh%vp%np_local
848        opt%w(:, ip) = opgt%w(:, mesh%vp%local(mesh%vp%xlocal+ip-1))
849      end do
850      call nl_operator_end(opg)
851      call nl_operator_end(opgt)
852      SAFE_DEALLOCATE_P(opg)
853      SAFE_DEALLOCATE_P(opgt)
854    end if
855#endif
856
857    POP_SUB(nl_operator_selfadjoint)
858  end subroutine nl_operator_selfadjoint
859
860
861#if defined(HAVE_MPI)
862
863  ! ---------------------------------------------------------
864  !> Like nl_operator_gather but opg is present on all nodes
865  !! (so do not forget to call nl_operator_end on all nodes
866  !! afterwards).
867  subroutine nl_operator_allgather(op, opg)
868    type(nl_operator_t), intent(in)  :: op
869    type(nl_operator_t), intent(out) :: opg
870
871    integer :: ip
872
873    PUSH_SUB(nl_operator_allgather)
874
875    ! Copy elements of op to opg that
876    ! are independent from the partitions, i.e. everything
877    ! except op%index and -- in the non-constant case -- op%w
878    call nl_operator_common_copy(op, opg)
879
880    ! Weights have to be collected only if they are not constant.
881    if(.not.op%const_w) then
882      do ip = 1, op%stencil%size
883        call vec_allgather(op%mesh%vp, opg%w(ip, :), op%w(ip, :))
884      end do
885    end if
886
887    POP_SUB(nl_operator_allgather)
888
889  end subroutine nl_operator_allgather
890
891  ! ---------------------------------------------------------
892  ! The following are private routines.
893  ! ---------------------------------------------------------
894
895  ! ---------------------------------------------------------
896  !> Copies all parts of op to opg that are independent of
897  !! the partitions, i.e. everything except op%index and -- in the
898  !! non-constant case -- op%w
899  !! This can be considered as nl_operator_copy and
900  !! reallocating w and i.
901  !! \warning: this should be replaced by a normal copy with a flag.
902  subroutine nl_operator_common_copy(op, opg)
903    type(nl_operator_t), target, intent(in)  :: op
904    type(nl_operator_t),         intent(out) :: opg
905
906    PUSH_SUB(nl_operator_common_copy)
907
908    call nl_operator_init(opg, op%label)
909
910    call stencil_copy(op%stencil, opg%stencil)
911
912    if(op%const_w) then
913      SAFE_ALLOCATE(opg%w(1:op%stencil%size, 1:1))
914    else
915      SAFE_ALLOCATE(opg%w(1:op%stencil%size, 1:op%mesh%np_global))
916    end if
917    opg%mesh     => op%mesh
918    opg%np       =  op%mesh%np_global
919    opg%const_w  =  op%const_w
920    opg%nri      =  op%nri
921    if(op%const_w) then
922      opg%w = op%w
923    end if
924
925    POP_SUB(nl_operator_common_copy)
926
927  end subroutine nl_operator_common_copy
928
929
930  ! ---------------------------------------------------------
931  ! End of private routines.
932  ! ---------------------------------------------------------
933#endif
934
935  ! ---------------------------------------------------------
936  subroutine nl_operator_end(op)
937    type(nl_operator_t), intent(inout) :: op
938
939    PUSH_SUB(nl_operator_end)
940
941    if(accel_is_enabled() .and. op%const_w) then
942
943      call accel_release_buffer(op%buff_ri)
944      select case(function_opencl)
945      case(OP_INVMAP)
946        call accel_release_buffer(op%buff_imin)
947        call accel_release_buffer(op%buff_imax)
948
949      case(OP_MAP)
950        call accel_release_buffer(op%buff_map)
951        if(op%mesh%parallel_in_domains) then
952          call accel_release_buffer(op%buff_all)
953          call accel_release_buffer(op%buff_inner)
954          call accel_release_buffer(op%buff_outer)
955        end if
956
957      case(OP_NOMAP)
958        call accel_release_buffer(op%buff_map)
959        call accel_release_buffer(op%buff_stencil)
960        call accel_release_buffer(op%buff_xyz_to_ip)
961        call accel_release_buffer(op%buff_ip_to_xyz)
962      end select
963    end if
964
965    if(op%mesh%parallel_in_domains) then
966      SAFE_DEALLOCATE_P(op%inner%imin)
967      SAFE_DEALLOCATE_P(op%inner%imax)
968      SAFE_DEALLOCATE_P(op%inner%ri)
969      SAFE_DEALLOCATE_P(op%outer%imin)
970      SAFE_DEALLOCATE_P(op%outer%imax)
971      SAFE_DEALLOCATE_P(op%outer%ri)
972    end if
973
974    SAFE_DEALLOCATE_P(op%index)
975    SAFE_DEALLOCATE_P(op%w)
976
977    SAFE_DEALLOCATE_P(op%ri)
978    SAFE_DEALLOCATE_P(op%rimap)
979    SAFE_DEALLOCATE_P(op%rimap_inv)
980    SAFE_DEALLOCATE_P(op%nn)
981
982    call stencil_end(op%stencil)
983
984    POP_SUB(nl_operator_end)
985  end subroutine nl_operator_end
986
987
988  ! ---------------------------------------------------------
989  integer pure function nl_operator_get_index(op, is, ip) result(res)
990    type(nl_operator_t), intent(in)   :: op
991    integer,             intent(in)   :: is
992    integer,             intent(in)   :: ip
993
994    res = ip + op%ri(is, op%rimap(ip))
995  end function nl_operator_get_index
996
997  ! ---------------------------------------------------------
998
999  integer pure function nl_operator_np_zero_bc(op) result(np_bc)
1000    type(nl_operator_t), intent(in)   :: op
1001
1002    integer :: jj, ii
1003
1004    np_bc = 0
1005    do jj = 1, op%nri
1006      ii = op%rimap_inv(jj + 1) + maxval(op%ri(1:op%stencil%size, jj))
1007      np_bc = max(np_bc, ii)
1008    end do
1009
1010  end function nl_operator_np_zero_bc
1011
1012  ! ------------------------------------------------------
1013
1014  logical pure function nl_operator_compact_boundaries(op)
1015    type(nl_operator_t), intent(in)   :: op
1016
1017    nl_operator_compact_boundaries = compact_boundaries
1018  end function nl_operator_compact_boundaries
1019
1020
1021#include "undef.F90"
1022#include "real.F90"
1023#include "nl_operator_inc.F90"
1024
1025#include "undef.F90"
1026#include "complex.F90"
1027#include "nl_operator_inc.F90"
1028
1029end module nl_operator_oct_m
1030
1031!! Local Variables:
1032!! mode: f90
1033!! coding: utf-8
1034!! End:
1035