1 // Copyright (c) 2010-2021, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 #include "../quadinterpolator.hpp"
13 #include "../../general/forall.hpp"
14 #include "../../linalg/dtensor.hpp"
15 #include "../../fem/kernels.hpp"
16 #include "../../linalg/kernels.hpp"
17 
18 using namespace mfem;
19 
20 namespace mfem
21 {
22 
23 namespace internal
24 {
25 
26 namespace quadrature_interpolator
27 {
28 
29 template<int T_D1D = 0, int T_Q1D = 0, int MAX_D1D = 0, int MAX_Q1D = 0>
Det2D(const int NE,const double * b,const double * g,const double * x,double * y,const int vdim=1,const int d1d=0,const int q1d=0)30 static void Det2D(const int NE,
31                   const double *b,
32                   const double *g,
33                   const double *x,
34                   double *y,
35                   const int vdim = 1,
36                   const int d1d = 0,
37                   const int q1d = 0)
38 {
39    constexpr int DIM = 2;
40    static constexpr int NBZ = 1;
41 
42    const int D1D = T_D1D ? T_D1D : d1d;
43    const int Q1D = T_Q1D ? T_Q1D : q1d;
44 
45    const auto B = Reshape(b, Q1D, D1D);
46    const auto G = Reshape(g, Q1D, D1D);
47    const auto X = Reshape(x,  D1D, D1D, DIM, NE);
48    auto Y = Reshape(y, Q1D, Q1D, NE);
49 
50    MFEM_FORALL_2D(e, NE, Q1D, Q1D, NBZ,
51    {
52       constexpr int MQ1 = T_Q1D ? T_Q1D : MAX_Q1D;
53       constexpr int MD1 = T_D1D ? T_D1D : MAX_D1D;
54       const int D1D = T_D1D ? T_D1D : d1d;
55       const int Q1D = T_Q1D ? T_Q1D : q1d;
56 
57       MFEM_SHARED double BG[2][MQ1*MD1];
58       MFEM_SHARED double XY[2][NBZ][MD1*MD1];
59       MFEM_SHARED double DQ[4][NBZ][MD1*MQ1];
60       MFEM_SHARED double QQ[4][NBZ][MQ1*MQ1];
61 
62       kernels::internal::LoadX<MD1,NBZ>(e,D1D,X,XY);
63       kernels::internal::LoadBG<MD1,MQ1>(D1D,Q1D,B,G,BG);
64 
65       kernels::internal::GradX<MD1,MQ1,NBZ>(D1D,Q1D,BG,XY,DQ);
66       kernels::internal::GradY<MD1,MQ1,NBZ>(D1D,Q1D,BG,DQ,QQ);
67 
68       MFEM_FOREACH_THREAD(qy,y,Q1D)
69       {
70          MFEM_FOREACH_THREAD(qx,x,Q1D)
71          {
72             double J[4];
73             kernels::internal::PullGrad<MQ1,NBZ>(Q1D,qx,qy,QQ,J);
74             Y(qx,qy,e) = kernels::Det<2>(J);
75          }
76       }
77    });
78 }
79 
80 template<int T_D1D = 0, int T_Q1D = 0, int MAX_D1D = 0, int MAX_Q1D = 0,
81          bool SMEM = true>
Det3D(const int NE,const double * b,const double * g,const double * x,double * y,const int vdim=1,const int d1d=0,const int q1d=0,Vector * d_buff=nullptr)82 static void Det3D(const int NE,
83                   const double *b,
84                   const double *g,
85                   const double *x,
86                   double *y,
87                   const int vdim = 1,
88                   const int d1d = 0,
89                   const int q1d = 0,
90                   Vector *d_buff = nullptr) // used only with SMEM = false
91 {
92    constexpr int DIM = 3;
93    static constexpr int MQ1 = T_Q1D ? T_Q1D : MAX_Q1D;
94    static constexpr int MD1 = T_D1D ? T_D1D : MAX_D1D;
95    static constexpr int MDQ = MQ1 > MD1 ? MQ1 : MD1;
96    static constexpr int MSZ = MDQ * MDQ * MDQ * 9;
97    static constexpr int GRID = SMEM ? 0 : 128;
98 
99    const int D1D = T_D1D ? T_D1D : d1d;
100    const int Q1D = T_Q1D ? T_Q1D : q1d;
101 
102    const auto B = Reshape(b, Q1D, D1D);
103    const auto G = Reshape(g, Q1D, D1D);
104    const auto X = Reshape(x, D1D, D1D, D1D, DIM, NE);
105    auto Y = Reshape(y, Q1D, Q1D, Q1D, NE);
106 
107    double *GM = nullptr;
108    if (!SMEM)
109    {
110       d_buff->SetSize(2*MSZ*GRID);
111       GM = d_buff->Write();
112    }
113 
114    MFEM_FORALL_3D_GRID(e, NE, Q1D, Q1D, Q1D, GRID,
115    {
116       const int bid = MFEM_BLOCK_ID(x);
117       MFEM_SHARED double BG[2][MQ1*MD1];
118       MFEM_SHARED double SM0[SMEM?MSZ:1];
119       MFEM_SHARED double SM1[SMEM?MSZ:1];
120       double *lm0 = SMEM ? SM0 : GM + MSZ*bid;
121       double *lm1 = SMEM ? SM1 : GM + MSZ*(GRID+bid);
122       double (*DDD)[MD1*MD1*MD1] = (double (*)[MD1*MD1*MD1]) (lm0);
123       double (*DDQ)[MD1*MD1*MQ1] = (double (*)[MD1*MD1*MQ1]) (lm1);
124       double (*DQQ)[MD1*MQ1*MQ1] = (double (*)[MD1*MQ1*MQ1]) (lm0);
125       double (*QQQ)[MQ1*MQ1*MQ1] = (double (*)[MQ1*MQ1*MQ1]) (lm1);
126 
127       kernels::internal::LoadX<MD1>(e,D1D,X,DDD);
128       kernels::internal::LoadBG<MD1,MQ1>(D1D,Q1D,B,G,BG);
129 
130       kernels::internal::GradX<MD1,MQ1>(D1D,Q1D,BG,DDD,DDQ);
131       kernels::internal::GradY<MD1,MQ1>(D1D,Q1D,BG,DDQ,DQQ);
132       kernels::internal::GradZ<MD1,MQ1>(D1D,Q1D,BG,DQQ,QQQ);
133 
134       MFEM_FOREACH_THREAD(qz,z,Q1D)
135       {
136          MFEM_FOREACH_THREAD(qy,y,Q1D)
137          {
138             MFEM_FOREACH_THREAD(qx,x,Q1D)
139             {
140                double J[9];
141                kernels::internal::PullGrad<MQ1>(Q1D, qx,qy,qz, QQQ, J);
142                Y(qx,qy,qz,e) = kernels::Det<3>(J);
143             }
144          }
145       }
146    });
147 }
148 
149 // Tensor-product evaluation of quadrature point determinants: dispatch
150 // function.
TensorDeterminants(const int NE,const int vdim,const DofToQuad & maps,const Vector & e_vec,Vector & q_det,Vector & d_buff)151 void TensorDeterminants(const int NE,
152                         const int vdim,
153                         const DofToQuad &maps,
154                         const Vector &e_vec,
155                         Vector &q_det,
156                         Vector &d_buff)
157 {
158    if (NE == 0) { return; }
159    const int dim = maps.FE->GetDim();
160    const int D1D = maps.ndof;
161    const int Q1D = maps.nqpt;
162    const double *B = maps.B.Read();
163    const double *G = maps.G.Read();
164    const double *X = e_vec.Read();
165    double *Y = q_det.Write();
166 
167    const int id = (vdim<<8) | (D1D<<4) | Q1D;
168 
169    if (dim == 2)
170    {
171       switch (id)
172       {
173          case 0x222: return Det2D<2,2>(NE,B,G,X,Y);
174          case 0x223: return Det2D<2,3>(NE,B,G,X,Y);
175          case 0x224: return Det2D<2,4>(NE,B,G,X,Y);
176          case 0x226: return Det2D<2,6>(NE,B,G,X,Y);
177          case 0x234: return Det2D<3,4>(NE,B,G,X,Y);
178          case 0x236: return Det2D<3,6>(NE,B,G,X,Y);
179          case 0x244: return Det2D<4,4>(NE,B,G,X,Y);
180          case 0x246: return Det2D<4,6>(NE,B,G,X,Y);
181          case 0x256: return Det2D<5,6>(NE,B,G,X,Y);
182          default:
183          {
184             constexpr int MD = MAX_D1D;
185             constexpr int MQ = MAX_Q1D;
186             MFEM_VERIFY(D1D <= MD, "Orders higher than " << MD-1
187                         << " are not supported!");
188             MFEM_VERIFY(Q1D <= MQ, "Quadrature rules with more than "
189                         << MQ << " 1D points are not supported!");
190             Det2D<0,0,MD,MQ>(NE,B,G,X,Y,vdim,D1D,Q1D);
191             return;
192          }
193       }
194    }
195    if (dim == 3)
196    {
197       switch (id)
198       {
199          case 0x324: return Det3D<2,4>(NE,B,G,X,Y);
200          case 0x333: return Det3D<3,3>(NE,B,G,X,Y);
201          case 0x335: return Det3D<3,5>(NE,B,G,X,Y);
202          case 0x336: return Det3D<3,6>(NE,B,G,X,Y);
203          default:
204          {
205             constexpr int MD = 6;
206             constexpr int MQ = 6;
207             // Highest orders that fit in shared mememory
208             if (D1D <= MD && Q1D <= MQ)
209             { return Det3D<0,0,MD,MQ>(NE,B,G,X,Y,vdim,D1D,Q1D); }
210             // Last fall-back will use global memory
211             return Det3D<0,0,MAX_D1D,MAX_Q1D,false>(
212                       NE,B,G,X,Y,vdim,D1D,Q1D,&d_buff);
213          }
214       }
215    }
216    MFEM_ABORT("Kernel " << std::hex << id << std::dec << " not supported yet");
217 }
218 
219 } // namespace quadrature_interpolator
220 
221 } // namespace internal
222 
223 } // namespace mfem
224