1 // -*- tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*-
2 // vi: set et ts=4 sw=2 sts=2:
3 #ifndef DUNE_FUNCTIONS_FUNCTIONSPACEBASES_TEST_BASISTEST_HH
4 #define DUNE_FUNCTIONS_FUNCTIONSPACEBASES_TEST_BASISTEST_HH
5 
#include <algorithm>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
10 
11 #include <dune/common/test/testsuite.hh>
12 #include <dune/common/concept.hh>
13 #include <dune/common/typetraits.hh>
14 #include <dune/common/hybridutilities.hh>
15 
16 #include <dune/geometry/quadraturerules.hh>
17 
18 #include <dune/functions/functionspacebases/concepts.hh>
19 
// Tag type for checkBasis() flags. It is not referenced elsewhere in this header.
struct CheckBasisFlag {};

// Flag for checkBasis()/checkLocalView(): when passed, the check that no
// shape function is constant zero is skipped.
struct AllowZeroBasisFunctions {};

// Compile-time predicate: IsContained<T, S...>::value is true iff T is
// exactly one of the types S... (false for an empty pack).
template<class T, class... S>
struct IsContained : public std::bool_constant<(std::is_same_v<T, S> || ...)>
{};
26 
27 
28 
/*
 * Build a human-readable identifier for an element of the form
 * "<element.type()>#<index in gridView.indexSet()>", used in test messages.
 */
template<class Element, class GridView>
std::string elementStr(const Element& element, const GridView& gridView)
{
  const auto elementIndex = gridView.indexSet().index(element);
  std::ostringstream label;
  label << element.type() << '#' << elementIndex;
  return label.str();
}
39 
/*
 * Check whether multi-index b directly follows multi-index a in a
 * lexicographically ordered index tree.
 * This is used by checkBasisIndexTreeConsistency().
 *
 * After the longest common prefix of a and b:
 * - if a still has entries but b does not (b is a strict prefix of a), the
 *   result is false;
 * - if both still have entries, the next entry of b must be the next entry
 *   of a plus one;
 * - every remaining entry of b must be zero.
 * Note that the function also returns true for a == b and for a being a
 * prefix of b followed only by zeros.
 */
template<class MultiIndex>
bool multiIndicesConsecutive(const MultiIndex& a, const MultiIndex& b)
{
  const std::size_t sizeA = a.size();
  const std::size_t sizeB = b.size();

  // Skip the longest common prefix.
  std::size_t pos = 0;
  while (pos < sizeA && pos < sizeB && a[pos] == b[pos])
    ++pos;

  if (pos < sizeA)
  {
    // b ran out first: b is a strict prefix of a and cannot succeed it.
    if (pos == sizeB)
      return false;
    // First differing entry must be an increment by one.
    if (b[pos] != a[pos] + 1)
      return false;
    ++pos;
  }

  // Whatever is left of b must be the first child, i.e. all zeros.
  for (; pos < sizeB; ++pos)
    if (b[pos] != 0)
      return false;

  return true;
}
76 
77 
78 
79 /*
80  * Check if given set of multi-indices is consistent, i.e.,
81  * if it induces a consistent ordered tree. This is used
82  * by checkBasisIndices()
83  */
84 template<class MultiIndexSet>
checkBasisIndexTreeConsistency(const MultiIndexSet & multiIndexSet)85 Dune::TestSuite checkBasisIndexTreeConsistency(const MultiIndexSet& multiIndexSet)
86 {
87   Dune::TestSuite test("index tree consistency check");
88 
89   using namespace Dune;
90 
91   auto it = multiIndexSet.begin();
92   auto end = multiIndexSet.end();
93 
94   // get first multi-index
95   auto lastMultiIndex = *it;
96 
97   // assert that index is non-empty
98   test.require(lastMultiIndex.size()>0, "multi-index size check")
99     << "empty multi-index found";
100 
101   // check if first multi-index is [0,...,0]
102   for (decltype(lastMultiIndex.size()) i = 0; i<lastMultiIndex.size(); ++i)
103   {
104     test.require(lastMultiIndex[i] == 0, "smallest index check")
105       << "smallest index contains non-zero entry " << lastMultiIndex[i] << " in position " << i;
106   }
107 
108   ++it;
109   for(; it != end; ++it)
110   {
111     auto multiIndex = *it;
112 
113     // assert that index is non-empty
114     test.require(multiIndex.size()>0, "multi-index size check")
115       << "empty multi-index found";
116 
117     // assert that indices are consecutive
118     test.check(multiIndicesConsecutive(lastMultiIndex, multiIndex), "consecutive index check")
119       << "multi-indices " << lastMultiIndex << " and " << multiIndex << " are subsequent but not consecutive";
120 
121     lastMultiIndex = multiIndex;
122   }
123 
124   return test;
125 }
126 
127 
128 
129 /*
130  * Check consistency of basis.size(prefix)
131  */
132 template<class Basis, class MultiIndexSet>
checkBasisSizeConsistency(const Basis & basis,const MultiIndexSet & multiIndexSet)133 Dune::TestSuite checkBasisSizeConsistency(const Basis& basis, const MultiIndexSet& multiIndexSet)
134 {
135   Dune::TestSuite test("index size consistency check");
136 
137   auto prefix = typename Basis::SizePrefix{};
138 
139   for(const auto& index : multiIndexSet)
140   {
141     prefix.clear();
142     for (const auto& i: index)
143     {
144       // All indices i collected so far from the multi-index
145       // refer to a non-empty multi-index subtree. Hence the
146       // size must be nonzero and in fact strictly larger than
147       // the next index.
148       auto prefixSize = basis.size(prefix);
149       test.require(prefixSize > i, "basis.size(prefix) subtree check")
150         << "basis.size(" << prefix << ")=" << prefixSize << " but index " << index << " exists";
151 
152       // append next index from multi-index
153       prefix.push_back(i);
154     }
155     auto prefixSize = basis.size(prefix);
156     test.require(prefixSize == 0, "basis.size(prefix) leaf check")
157       << "basis.size(" << prefix << ")=" << prefixSize << " but the prefix exists as index";
158   }
159 
160   // ToDo: Add check that for basis.size(prefix)==n with i>0
161   // there exist multi-indices of the form (prefix,0,...)...(prefix,n-1,...)
162 
163   return test;
164 }
165 
166 
167 
168 /*
169  * Check indices of basis:
170  * - First store the whole index tree in a set
171  * - Check if this corresponds to a consistent index tree
172  * - Check if index tree is consistent with basis.size(prefix) and basis.dimension()
173  */
174 template<class Basis>
checkBasisIndices(const Basis & basis)175 Dune::TestSuite checkBasisIndices(const Basis& basis)
176 {
177   Dune::TestSuite test("basis index check");
178 
179   using MultiIndex = typename Basis::MultiIndex;
180 
181   static_assert(Dune::IsIndexable<MultiIndex>(), "MultiIndex must support operator[]");
182 
183   auto compare = [](const auto& a, const auto& b) {
184     return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
185   };
186 
187   auto multiIndexSet = std::set<MultiIndex, decltype(compare)>{compare};
188 
189   auto localView = basis.localView();
190   for (const auto& e : elements(basis.gridView()))
191   {
192     localView.bind(e);
193 
194     test.require(localView.size() <= localView.maxSize(), "localView.size() check")
195       << "localView.size() is " << localView.size() << " but localView.maxSize() is " << localView.maxSize();
196 
197     for (decltype(localView.size()) i=0; i< localView.size(); ++i)
198     {
199       auto multiIndex = localView.index(i);
200       for(auto mi: multiIndex)
201         test.check(mi>=0)
202           << "Global multi-index containes negative entry for shape function " << i
203           << " in element " << elementStr(localView.element(), basis.gridView());
204       multiIndexSet.insert(multiIndex);
205     }
206   }
207 
208   test.subTest(checkBasisIndexTreeConsistency(multiIndexSet));
209   test.subTest(checkBasisSizeConsistency(basis, multiIndexSet));
210   test.check(basis.dimension() == multiIndexSet.size())
211     << "basis.dimension() does not equal the total number of basis functions.";
212 
213   return test;
214 }
215 
216 
217 
218 /*
219  * Check if shape functions are not constant zero.
220  * This is called by checkLocalView().
221  */
222 template<class LocalFiniteElement>
checkNonZeroShapeFunctions(const LocalFiniteElement & fe,std::size_t order=5,double tol=1e-10)223 Dune::TestSuite checkNonZeroShapeFunctions(const LocalFiniteElement& fe, std::size_t order = 5, double tol = 1e-10)
224 {
225   Dune::TestSuite test;
226   static const int dimension = LocalFiniteElement::Traits::LocalBasisType::Traits::dimDomain;
227 
228   auto quadRule = Dune::QuadratureRules<double, dimension>::rule(fe.type(), order);
229 
230   std::vector<typename LocalFiniteElement::Traits::LocalBasisType::Traits::RangeType> values;
231   std::vector<bool> isNonZero;
232   isNonZero.resize(fe.size(), false);
233   for (const auto& qp : quadRule)
234   {
235     fe.localBasis().evaluateFunction(qp.position(), values);
236     for(std::size_t i=0; i<fe.size(); ++i)
237       isNonZero[i] = (isNonZero[i] or (values[i].infinity_norm() > tol));
238   }
239   for(std::size_t i=0; i<fe.size(); ++i)
240     test.check(isNonZero[i])
241       << "Found a constant zero basis function";
242   return test;
243 }
244 
245 
246 
247 /*
248  * Check localView. This especially checks for
249  * consistency of local indices and local size.
250  */
251 template<class Basis, class LocalView, class... Flags>
checkLocalView(const Basis & basis,const LocalView & localView,Flags...flags)252 Dune::TestSuite checkLocalView(const Basis& basis, const LocalView& localView, Flags... flags)
253 {
254   Dune::TestSuite test(std::string("LocalView on ") + elementStr(localView.element(), basis.gridView()));
255 
256   test.check(localView.size() <= localView.maxSize(), "localView.size() check")
257     << "localView.size() is " << localView.size() << " but localView.maxSize() is " << localView.maxSize();
258 
259   // Count all local indices appearing in the tree.
260   std::vector<std::size_t> localIndices;
261   localIndices.resize(localView.size(), 0);
262   Dune::TypeTree::forEachLeafNode(localView.tree(), [&](const auto& node, auto&& treePath) {
263     test.check(node.size() == node.finiteElement().size())
264       << "Size of leaf node and finite element are different.";
265     for(std::size_t i=0; i<node.size(); ++i)
266     {
267       test.check(node.localIndex(i) < localView.size())
268         << "Local index exceeds localView.size().";
269       if (node.localIndex(i) < localView.size())
270         ++(localIndices[node.localIndex(i)]);
271     }
272   });
273 
274   // Check if each local index appears exactly once.
275   for(std::size_t i=0; i<localView.size(); ++i)
276   {
277     if (localIndices[i])
278     test.check(localIndices[i]>=1)
279       << "Local index " << i << " did not appear";
280     test.check(localIndices[i]<=1)
281       << "Local index " << i << " appears multiple times";
282   }
283 
284   // Check if all basis functions are non-constant.
285   if (not IsContained<AllowZeroBasisFunctions, Flags...>::value)
286   {
287     Dune::TypeTree::forEachLeafNode(localView.tree(), [&](const auto& node, auto&& treePath) {
288       test.subTest(checkNonZeroShapeFunctions(node.finiteElement()));
289     });
290   }
291 
292   return test;
293 }
294 
295 
296 // Flag to enable a local continuity check for checking strong
297 // continuity across an intersection within checkBasisContinuity().
298 //
299 // For each inside basis function this will compute the jump against
300 // zero or the corresponding inside basis function. The latter is then
301 // checked for being (up to a tolerance) zero on a set of quadrature points.
302 struct EnableContinuityCheck
303 {
304   std::size_t order_ = 5;
305   double tol_ = 1e-10;
306 
307   template<class JumpEvaluator>
localJumpContinuityCheckEnableContinuityCheck308   auto localJumpContinuityCheck(const JumpEvaluator& jumpEvaluator, std::size_t order, double tol) const
309   {
310     return [=](const auto& intersection, const auto& treePath, const auto& insideNode, const auto& outsideNode, const auto& insideToOutside) {
311       using Intersection = std::decay_t<decltype(intersection)>;
312       using Node = std::decay_t<decltype(insideNode)>;
313 
314       std::vector<int> isContinuous(insideNode.size(), true);
315       const auto& quadRule = Dune::QuadratureRules<double, Intersection::mydimension>::rule(intersection.type(), order);
316 
317       using Range = typename Node::FiniteElement::Traits::LocalBasisType::Traits::RangeType;
318       std::vector<std::vector<Range>> values;
319       std::vector<std::vector<Range>> neighborValues;
320 
321       // Evaluate inside and outside basis functions.
322       values.resize(quadRule.size());
323       neighborValues.resize(quadRule.size());
324       for(std::size_t k=0; k<quadRule.size(); ++k)
325       {
326         auto pointInElement = intersection.geometryInInside().global(quadRule[k].position());
327         auto pointInNeighbor = intersection.geometryInOutside().global(quadRule[k].position());
328         insideNode.finiteElement().localBasis().evaluateFunction(pointInElement, values[k]);
329         outsideNode.finiteElement().localBasis().evaluateFunction(pointInNeighbor, neighborValues[k]);
330       }
331 
332       // Check jump against outside basis function or zero.
333       for(std::size_t i=0; i<insideNode.size(); ++i)
334       {
335         for(std::size_t k=0; k<quadRule.size(); ++k)
336         {
337           auto jump = values[k][i];
338           if (insideToOutside[i].has_value())
339             jump -= neighborValues[k][insideToOutside[i].value()];
340           isContinuous[i] = isContinuous[i] and (jumpEvaluator(jump, intersection, quadRule[k].position()) < tol);
341         }
342       }
343       return isContinuous;
344     };
345   }
346 
localContinuityCheckEnableContinuityCheck347   auto localContinuityCheck() const {
348     auto jumpNorm = [](auto&&jump, auto&& intersection, auto&& x) -> double {
349       return jump.infinity_norm();
350     };
351     return localJumpContinuityCheck(jumpNorm, order_, tol_);
352   }
353 };
354 
355 // Flag to enable a local normal-continuity check for checking strong
356 // continuity across an intersection within checkBasisContinuity().
357 //
358 // For each inside basis function this will compute the normal jump against
359 // zero or the corresponding inside basis function. The latter is then
360 // checked for being (up to a tolerance) zero on a set of quadrature points.
361 struct EnableNormalContinuityCheck : public EnableContinuityCheck
362 {
localContinuityCheckEnableNormalContinuityCheck363   auto localContinuityCheck() const {
364     auto normalJump = [](auto&&jump, auto&& intersection, auto&& x) -> double {
365       return jump * intersection.unitOuterNormal(x);
366     };
367     return localJumpContinuityCheck(normalJump, order_, tol_);
368   }
369 };
370 
371 // Flag to enable a local tangential-continuity check for checking continuity
372 // of tangential parts of a vector-valued basis across an intersection
373 // within checkBasisContinuity().
374 //
375 // For each inside basis function this will compute the tangential jump against
376 // zero or the corresponding outside basis function. The jump is then
377 // checked for being (up to a tolerance) zero on a set of quadrature points.
378 struct EnableTangentialContinuityCheck : public EnableContinuityCheck
379 {
localContinuityCheckEnableTangentialContinuityCheck380   auto localContinuityCheck() const {
381     auto tangentialJumpNorm = [](auto&&jump, auto&& intersection, auto&& x) -> double {
382       auto tangentialJump = jump - (jump * intersection.unitOuterNormal(x)) * intersection.unitOuterNormal(x);
383       return tangentialJump.two_norm();
384     };
385     return localJumpContinuityCheck(tangentialJumpNorm, order_, tol_);
386   }
387 };
388 
389 // Flag to enable a center continuity check for checking continuity in the
390 // center of an intersection within checkBasisContinuity().
391 //
392 // For each inside basis function this will compute the jump against
393 // zero or the corresponding inside basis function. The latter is then
394 // checked for being (up to a tolerance) zero in the center of mass
395 // of the intersection.
396 struct EnableCenterContinuityCheck : public EnableContinuityCheck
397 {
398   template<class JumpEvaluator>
localJumpCenterContinuityCheckEnableCenterContinuityCheck399   auto localJumpCenterContinuityCheck(const JumpEvaluator& jumpEvaluator, double tol) const
400   {
401     return [=](const auto& intersection, const auto& treePath, const auto& insideNode, const auto& outsideNode, const auto& insideToOutside) {
402       using Node = std::decay_t<decltype(insideNode)>;
403       using Range = typename Node::FiniteElement::Traits::LocalBasisType::Traits::RangeType;
404 
405       std::vector<int> isContinuous(insideNode.size(), true);
406       std::vector<Range> insideValues;
407       std::vector<Range> outsideValues;
408 
409       insideNode.finiteElement().localBasis().evaluateFunction(intersection.geometryInInside().center(), insideValues);
410       outsideNode.finiteElement().localBasis().evaluateFunction(intersection.geometryInOutside().center(), outsideValues);
411 
412       auto centerLocal = intersection.geometry().local(intersection.geometry().center());
413 
414       // Check jump against outside basis function or zero.
415       for(std::size_t i=0; i<insideNode.size(); ++i)
416       {
417           auto jump = insideValues[i];
418           if (insideToOutside[i].has_value())
419             jump -= outsideValues[insideToOutside[i].value()];
420           isContinuous[i] = isContinuous[i] and (jumpEvaluator(jump, intersection, centerLocal) < tol);
421       }
422       return isContinuous;
423     };
424   }
425 
localContinuityCheckEnableCenterContinuityCheck426   auto localContinuityCheck() const {
427     auto jumpNorm = [](auto&&jump, auto&& intersection, auto&& x) -> double {
428       return jump.infinity_norm();
429     };
430     return localJumpCenterContinuityCheck(jumpNorm, tol_);
431   }
432 };
433 
434 
435 /*
436  * Check if basis functions are continuous across faces.
437  * Continuity is checked by evaluation at a set of quadrature points
438  * from a quadrature rule of given order.
439  * If two basis functions (on neighboring elements) share the same
440  * global index, their values at the quadrature points (located on
441  * their intersection) should coincide up to the given tolerance.
442  *
443  * If a basis function only appears on one side of the intersection,
444  * it should be zero on the intersection.
445  */
446 template<class Basis, class LocalCheck>
checkBasisContinuity(const Basis & basis,const LocalCheck & localCheck)447 Dune::TestSuite checkBasisContinuity(const Basis& basis, const LocalCheck& localCheck)
448 {
449   Dune::TestSuite test("Global continuity check of basis functions");
450 
451 
452   auto localView = basis.localView();
453   auto neighborLocalView = basis.localView();
454 
455   for (const auto& e : elements(basis.gridView()))
456   {
457     localView.bind(e);
458     for(const auto& intersection : intersections(basis.gridView(), e))
459     {
460       if (intersection.neighbor())
461       {
462         neighborLocalView.bind(intersection.outside());
463 
464         Dune::TypeTree::forEachLeafNode(localView.tree(), [&](const auto& insideNode, auto&& treePath) {
465           const auto& outsideNode = Dune::TypeTree::child(neighborLocalView.tree(), treePath);
466 
467           std::vector<std::optional<int>> insideToOutside;
468           insideToOutside.resize(insideNode.size());
469 
470           // Map all inside DOFs to outside DOFs if possible
471           for(std::size_t i=0; i<insideNode.size(); ++i)
472           {
473             for(std::size_t j=0; j<outsideNode.size(); ++j)
474             {
475               if (localView.index(insideNode.localIndex(i)) == neighborLocalView.index(outsideNode.localIndex(j)))
476               {
477                 // Basis function should only appear once in the neighbor element.
478                 test.check(not insideToOutside[i].has_value())
479                   << "Basis function " << localView.index(insideNode.localIndex(i))
480                   << " appears twice in element " << elementStr(neighborLocalView.element(), basis.gridView());
481                 insideToOutside[i] = j;
482               }
483             }
484           }
485 
486           // Apply continuity check on given intersection with given inside/outside DOF node pair.
487           auto isContinuous = localCheck(intersection, treePath, insideNode, outsideNode, insideToOutside);
488 
489           for(std::size_t i=0; i<insideNode.size(); ++i)
490           {
491             test.check(isContinuous[i])
492               << "Basis function " << localView.index(insideNode.localIndex(i))
493               << " is discontinuous across intersection of elements "
494               << elementStr(localView.element(), basis.gridView())
495               << " and " << elementStr(neighborLocalView.element(), basis.gridView());
496           }
497         });
498       }
499     }
500   }
501   return test;
502 }
503 
504 template<class Basis, class... Flags>
checkConstBasis(const Basis & basis,Flags...flags)505 Dune::TestSuite checkConstBasis(const Basis& basis, Flags... flags)
506 {
507   Dune::TestSuite test("const basis check");
508 
509   using GridView = typename Basis::GridView;
510 
511   // Check if basis models the GlobalBasis concept.
512   test.check(Dune::models<Dune::Functions::Concept::GlobalBasis<GridView>, Basis>(), "global basis concept check")
513     << "type passed to checkBasis() does not model the GlobalBasis concept";
514 
515   // Perform all local tests.
516   auto localView = basis.localView();
517   for (const auto& e : elements(basis.gridView()))
518   {
519     localView.bind(e);
520     test.subTest(checkLocalView(basis, localView, flags...));
521   }
522 
523   // Perform global index tests.
524   test.subTest(checkBasisIndices(basis));
525 
526   // Perform continuity check.
527   // First capture flags in a tuple in order to iterate.
528   auto flagTuple = std::tie(flags...);
529   Dune::Hybrid::forEach(flagTuple, [&](auto&& flag) {
530     using Flag = std::decay_t<decltype(flag)>;
531     if constexpr (std::is_base_of_v<EnableContinuityCheck, Flag>)
532       test.subTest(checkBasisContinuity(basis, flag.localContinuityCheck()));
533   });
534 
535   return test;
536 }
537 
538 
539 template<class Basis, class... Flags>
checkBasis(Basis & basis,Flags...flags)540 Dune::TestSuite checkBasis(Basis& basis, Flags... flags)
541 {
542   Dune::TestSuite test("basis check");
543 
544   // Perform tests for a constant basis
545   test.subTest(checkConstBasis(basis,flags...));
546 
547   // Check update of gridView
548   auto gridView = basis.gridView();
549   basis.update(gridView);
550 
551   return test;
552 }
553 
554 
555 
556 
557 #endif // DUNE_FUNCTIONS_FUNCTIONSPACEBASES_TEST_BASISTEST_HH
558