1 // Copyright (c) 2010-2021, Lawrence Livermore National Security, LLC. Produced
2 // at the Lawrence Livermore National Laboratory. All Rights reserved. See files
3 // LICENSE and NOTICE for details. LLNL-CODE-806117.
4 //
5 // This file is part of the MFEM library. For more information and source code
6 // availability visit https://mfem.org.
7 //
8 // MFEM is free software; you can redistribute it and/or modify it under the
9 // terms of the BSD-3 license. We welcome feedback and contributions, see file
10 // CONTRIBUTING.md for details.
11 
12 #include "mfem.hpp"
13 #include "unit_tests.hpp"
14 #include "general/tinyxml2.h"
15 #include <stdio.h>
16 
17 #ifndef _WIN32
18 #include <unistd.h> // rmdir
19 #else
20 #include <direct.h> // _rmdir
21 #define rmdir(dir) _rmdir(dir)
22 #endif
23 
24 using namespace mfem;
25 
26 TEST_CASE("Save and load from collections", "[DataCollection]")
27 {
28    SECTION("VisIt data files")
29    {
30       std::cout<<"Testing VisIt data files"<<std::endl;
31       // Set up a small mesh and a couple of grid function on that mesh
32       Mesh mesh = Mesh::MakeCartesian2D(2, 3, Element::QUADRILATERAL, 0, 2.0, 3.0);
33       FiniteElementCollection *fec = new LinearFECollection;
34       FiniteElementSpace *fespace = new FiniteElementSpace(&mesh, fec);
35       GridFunction *u = new GridFunction(fespace);
36       GridFunction *v = new GridFunction(fespace);
37 
38       int N = u->Size();
39       for (int i = 0; i < N; ++i)
40       {
41          (*u)(i) = double(i);
42          (*v)(i) = double(N - i - 1);
43       }
44 
45       int intOrder = 3;
46 
47       QuadratureSpace *qspace = new QuadratureSpace(&mesh, intOrder);
48       QuadratureFunction *qs = new QuadratureFunction(qspace, 1);
49       QuadratureFunction *qv = new QuadratureFunction(qspace, 2);
50 
51       int Nq = qs->Size();
52       for (int i = 0; i < Nq; ++i)
53       {
54          (*qs)(i) = double(i);
55          (*qv)(2*i+0) = double(i);
56          (*qv)(2*i+1) = double(Nq - i - 1);
57       }
58 
59 
60       SECTION("Uncompressed MFEM format")
61       {
62          std::cout<<"Testing uncompressed MFEM format"<<std::endl;
63 
64          // Collect the mesh and grid functions into a DataCollection and test that they got in there
65          VisItDataCollection dc("base", &mesh);
66          dc.RegisterField("u", u);
67          dc.RegisterField("v", v);
68          dc.RegisterQField("qs",qs);
69          dc.RegisterQField("qv",qv);
70          dc.SetCycle(5);
71          dc.SetTime(8.0);
72          REQUIRE(dc.GetMesh() == &mesh);
73          bool has_u = dc.HasField("u");
74          REQUIRE(has_u);
75          bool has_v = dc.HasField("v");
76          REQUIRE(has_v);
77          bool has_qs = dc.HasQField("qs");
78          REQUIRE(has_qs);
79          bool has_qv = dc.HasQField("qv");
80          REQUIRE(has_qv);
81          REQUIRE(dc.GetCycle() == 5);
82          REQUIRE(dc.GetTime() == 8.0);
83 
84          // Save the DataCollection and load it into a new DataCollection for comparison
85          dc.SetPadDigits(5);
86          dc.Save();
87 
88          VisItDataCollection dc_new("base");
89          dc_new.SetPadDigits(5);
90          dc_new.Load(dc.GetCycle());
91          Mesh* mesh_new = dc_new.GetMesh();
92          GridFunction *u_new = dc_new.GetField("u");
93          GridFunction *v_new = dc_new.GetField("v");
94          QuadratureFunction *qs_new = dc_new.GetQField("qs");
95          QuadratureFunction *qv_new = dc_new.GetQField("qv");
96          REQUIRE(mesh_new);
97          REQUIRE(u_new);
98          REQUIRE(v_new);
99          REQUIRE(qs_new);
100          REQUIRE(qv_new);
101 
102          // Compare some collection parameters for old and new
103          std::string name, name_new;
104          name = dc.GetCollectionName();
105          name_new = dc_new.GetCollectionName();
106          REQUIRE(name == name_new);
107          REQUIRE(dc.GetCycle() == dc_new.GetCycle());
108          REQUIRE(dc.GetTime() == dc_new.GetTime());
109 
110          // Compare the new mesh with the old mesh
111          // (Just a basic comparison here, a full comparison should be done in Mesh unit testing)
112          REQUIRE(mesh.Dimension() == mesh_new->Dimension());
113          REQUIRE(mesh.SpaceDimension() == mesh_new->SpaceDimension());
114 
115          Vector vert, vert_diff;
116          mesh.GetVertices(vert);
117          mesh_new->GetVertices(vert_diff);
118          vert_diff -= vert;
119          REQUIRE(vert_diff.Normlinf() < 1e-10);
120 
121          // Compare the old and new grid functions
122          // (Just a basic comparison here, a full comparison should be done in GridFunction unit testing)
123          Vector u_diff(*u_new), v_diff(*v_new);
124          u_diff -= *u;
125          v_diff -= *v;
126          REQUIRE(u_diff.Normlinf() < 1e-10);
127          REQUIRE(v_diff.Normlinf() < 1e-10);
128 
129          // Compare the old and new quadrature functions
130          // (Just a basic comparison here, a full comparison should be done in GridFunction unit testing)
131          Vector qs_diff(*qs_new), qv_diff(*qv_new);
132          qs_diff -= *qs;
133          qv_diff -= *qv;
134          REQUIRE(qs_diff.Normlinf() < 1e-10);
135          REQUIRE(qv_diff.Normlinf() < 1e-10);
136 
137          // Cleanup all the files
138          REQUIRE(remove("base_00005.mfem_root") == 0);
139          REQUIRE(remove("base_00005/mesh.00000") == 0);
140          REQUIRE(remove("base_00005/u.00000") == 0);
141          REQUIRE(remove("base_00005/v.00000") == 0);
142          REQUIRE(remove("base_00005/qs.00000") == 0);
143          REQUIRE(remove("base_00005/qv.00000") == 0);
144          REQUIRE(rmdir("base_00005") == 0);
145       }
146 
147 #ifdef MFEM_USE_ZLIB
148       SECTION("Compressed MFEM format")
149       {
150          std::cout<<"Testing compressed MFEM format"<<std::endl;
151 
152          // Collect the mesh and grid functions into a DataCollection and test that they got in there
153          VisItDataCollection dc("base", &mesh);
154          dc.RegisterField("u", u);
155          dc.RegisterField("v", v);
156          dc.RegisterQField("qs",qs);
157          dc.RegisterQField("qv",qv);
158          dc.SetCycle(5);
159          dc.SetTime(8.0);
160          REQUIRE(dc.GetMesh() == &mesh);
161          bool has_u = dc.HasField("u");
162          REQUIRE(has_u);
163          bool has_v = dc.HasField("v");
164          REQUIRE(has_v);
165          bool has_qs = dc.HasQField("qs");
166          REQUIRE(has_qs);
167          bool has_qv = dc.HasQField("qv");
168          REQUIRE(has_qv);
169          REQUIRE(dc.GetCycle() == 5);
170          REQUIRE(dc.GetTime() == 8.0);
171 
172          // Save the DataCollection and load it into a new DataCollection for comparison
173          dc.SetPadDigits(5);
174          dc.SetCompression(true);
175          dc.Save();
176 
177          VisItDataCollection dc_new("base");
178          dc_new.SetPadDigits(5);
179          dc_new.Load(dc.GetCycle());
180          Mesh *mesh_new = dc_new.GetMesh();
181          GridFunction *u_new = dc_new.GetField("u");
182          GridFunction *v_new = dc_new.GetField("v");
183          QuadratureFunction *qs_new = dc_new.GetQField("qs");
184          QuadratureFunction *qv_new = dc_new.GetQField("qv");
185          REQUIRE(mesh_new);
186          REQUIRE(u_new);
187          REQUIRE(v_new);
188          REQUIRE(qs_new);
189          REQUIRE(qv_new);
190 
191          // Compare some collection parameters for old and new
192          std::string name, name_new;
193          name = dc.GetCollectionName();
194          name_new = dc_new.GetCollectionName();
195          REQUIRE(name == name_new);
196          REQUIRE(dc.GetCycle() == dc_new.GetCycle());
197          REQUIRE(dc.GetTime() == dc_new.GetTime());
198 
199          // Compare the new mesh with the old mesh
200          // (Just a basic comparison here, a full comparison should be done in Mesh unit testing)
201          REQUIRE(mesh.Dimension() == mesh_new->Dimension());
202          REQUIRE(mesh.SpaceDimension() == mesh_new->SpaceDimension());
203 
204          Vector vert, vert_diff;
205          mesh.GetVertices(vert);
206          mesh_new->GetVertices(vert_diff);
207          vert_diff -= vert;
208          REQUIRE(vert_diff.Normlinf() < 1e-10);
209 
210          // Compare the old and new grid functions
211          // (Just a basic comparison here, a full comparison should be done in GridFunction unit testing)
212          Vector u_diff(*u_new), v_diff(*v_new);
213          u_diff -= *u;
214          v_diff -= *v;
215          REQUIRE(u_diff.Normlinf() < 1e-10);
216          REQUIRE(v_diff.Normlinf() < 1e-10);
217 
218          // Compare the old and new quadrature functions
219          // (Just a basic comparison here, a full comparison should be done in GridFunction unit testing)
220          Vector qs_diff(*qs_new), qv_diff(*qv_new);
221          qs_diff -= *qs;
222          qv_diff -= *qv;
223          REQUIRE(qs_diff.Normlinf() < 1e-10);
224          REQUIRE(qv_diff.Normlinf() < 1e-10);
225 
226          // Cleanup all the files
227          REQUIRE(remove("base_00005.mfem_root") == 0);
228          REQUIRE(remove("base_00005/mesh.00000") == 0);
229          REQUIRE(remove("base_00005/u.00000") == 0);
230          REQUIRE(remove("base_00005/v.00000") == 0);
231          REQUIRE(remove("base_00005/qs.00000") == 0);
232          REQUIRE(remove("base_00005/qv.00000") == 0);
233          REQUIRE(rmdir("base_00005") == 0);
234       }
235 #endif
236    }
237 
238 }
239 
SaveDataCollection(DataCollection & dc,int cycle,double t)240 void SaveDataCollection(DataCollection &dc, int cycle, double t)
241 {
242    dc.SetCycle(cycle);
243    dc.SetTime(t);
244    dc.Save();
245 }
246 
247 TEST_CASE("ParaView restart mode", "[ParaView]")
248 {
249    Mesh mesh = Mesh::MakeCartesian2D(2, 3, Element::QUADRILATERAL);
250    H1_FECollection fec(1, mesh.Dimension());
251    FiniteElementSpace fes(&mesh, &fec);
252    GridFunction u(&fes);
253    u = 0.0;
254 
255    // Write initial dataset with three timesteps: 0, 1, 2.
256    {
257       ParaViewDataCollection dc("ParaView", &mesh);
258       dc.RegisterField("u", &u);
259       SaveDataCollection(dc, 0, 0);
260       SaveDataCollection(dc, 1, 1);
261       SaveDataCollection(dc, 2, 2);
262    }
263 
264    // Using restart mode, append to the existing dataset, overwriting timesteps
265    // 1 and 2 with 1 and 1.5.
266    {
267       ParaViewDataCollection dc("ParaView", &mesh);
268       dc.UseRestartMode(true);
269       dc.RegisterField("u", &u);
270       SaveDataCollection(dc, 1, 1.0);
271       SaveDataCollection(dc, 2, 1.5);
272    }
273 
274    // Parse the resulting PVD file, and verify that the structure is correct,
275    // and that it contains three timesteps: 0, 1, and 1.5.
276    using namespace tinyxml2;
277    auto StringCompare = [](const char *s1, const char *s2)
__anonc120cefb0102(const char *s1, const char *s2) 278    {
279       if (s1 == NULL || s2 == NULL) { return false; }
280       return strcmp(s1, s2) == 0;
281    };
282    auto VerifyDataset = [StringCompare](const XMLElement *ds, double t_ref)
__anonc120cefb0202(const XMLElement *ds, double t_ref) 283    {
284       REQUIRE(ds);
285       REQUIRE(StringCompare(ds->Name(), "DataSet"));
286       const char *timestep = ds->Attribute("timestep");
287       REQUIRE(timestep);
288       double t = std::stod(timestep);
289       REQUIRE(t == MFEM_Approx(t_ref));
290    };
291 
292    XMLDocument xml;
293    xml.LoadFile("ParaView/ParaView.pvd");
294    REQUIRE(xml.ErrorID() == XML_SUCCESS);
295 
296    const XMLElement *vtkfile = xml.FirstChildElement();
297    REQUIRE(vtkfile);
298    REQUIRE(StringCompare(vtkfile->Name(), "VTKFile"));
299    const XMLElement *collection = vtkfile->FirstChildElement();
300    REQUIRE(collection);
301    REQUIRE(StringCompare(collection->Name(), "Collection"));
302 
303    const XMLElement *dataset = collection->FirstChildElement();
304    VerifyDataset(dataset, 0.0);
305    dataset = dataset->NextSiblingElement();
306    VerifyDataset(dataset, 1.0);
307    dataset = dataset->NextSiblingElement();
308    VerifyDataset(dataset, 1.5);
309    REQUIRE(dataset->NextSiblingElement() == NULL);
310 
311    // Clean up
312    for (int c=0; c<=2; ++c)
313    {
314       std::string prefix = "ParaView/Cycle00000" + std::to_string(c);
315       REQUIRE(remove((prefix + "/data.pvtu").c_str()) == 0);
316       REQUIRE(remove((prefix + "/proc000000.vtu").c_str()) == 0);
317       REQUIRE(rmdir(prefix.c_str()) == 0);
318    }
319    REQUIRE(remove("ParaView/ParaView.pvd") == 0);
320    REQUIRE(rmdir("ParaView") == 0);
321 }
322